mirror of
https://github.com/csukuangfj/kaldifeat.git
synced 2025-08-10 18:42:17 +00:00
commit
bac4db61c3
89
.github/workflows/publish_to_pypi.yml
vendored
89
.github/workflows/publish_to_pypi.yml
vendored
@ -1,3 +1,19 @@
|
|||||||
|
# Copyright 2021 Xiaomi Corp. (author: Fangjun Kuang)
|
||||||
|
|
||||||
|
# See ../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
name: Publish to PyPI
|
name: Publish to PyPI
|
||||||
|
|
||||||
on:
|
on:
|
||||||
@ -7,30 +23,85 @@ on:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
pypi:
|
pypi:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ${{ matrix.os }}
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
os: [ubuntu-18.04]
|
||||||
|
cuda: ["10.1"]
|
||||||
|
gcc: ["5"]
|
||||||
|
torch: ["1.8.1"]
|
||||||
|
python-version: [3.6, 3.7, 3.8]
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v2
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
- name: Setup Python
|
- name: Setup Python
|
||||||
uses: actions/setup-python@v2
|
uses: actions/setup-python@v2
|
||||||
with:
|
with:
|
||||||
python-version: 3.6
|
python-version: ${{ matrix.python-version }}
|
||||||
|
|
||||||
- name: Install Python dependencies
|
- name: Install CUDA Toolkit ${{ matrix.cuda }}
|
||||||
|
shell: bash
|
||||||
|
env:
|
||||||
|
cuda: ${{ matrix.cuda }}
|
||||||
|
run: |
|
||||||
|
source ./scripts/github_actions/install_cuda.sh
|
||||||
|
echo "CUDA_HOME=${CUDA_HOME}" >> $GITHUB_ENV
|
||||||
|
echo "${CUDA_HOME}/bin" >> $GITHUB_PATH
|
||||||
|
echo "LD_LIBRARY_PATH=${CUDA_HOME}/lib:${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Display NVCC version
|
||||||
|
run: |
|
||||||
|
which nvcc
|
||||||
|
nvcc --version
|
||||||
|
|
||||||
|
- name: Install GCC ${{ matrix.gcc }}
|
||||||
|
run: |
|
||||||
|
sudo apt-get install -y gcc-${{ matrix.gcc }} g++-${{ matrix.gcc }}
|
||||||
|
echo "CC=/usr/bin/gcc-${{ matrix.gcc }}" >> $GITHUB_ENV
|
||||||
|
echo "CXX=/usr/bin/g++-${{ matrix.gcc }}" >> $GITHUB_ENV
|
||||||
|
echo "CUDAHOSTCXX=/usr/bin/g++-${{ matrix.gcc }}" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Install PyTorch ${{ matrix.torch }}
|
||||||
|
env:
|
||||||
|
cuda: ${{ matrix.cuda }}
|
||||||
|
torch: ${{ matrix.torch }}
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
python3 -m pip install --upgrade pip
|
python3 -m pip install --upgrade pip
|
||||||
python3 -m pip install wheel twine setuptools
|
python3 -m pip install wheel twine typing_extensions
|
||||||
|
python3 -m pip install bs4 requests tqdm
|
||||||
|
|
||||||
- name: Build
|
./scripts/github_actions/install_torch.sh
|
||||||
shell: bash
|
python3 -c "import torch; print('torch version:', torch.__version__)"
|
||||||
|
|
||||||
|
- name: Download cudnn 8.0
|
||||||
|
env:
|
||||||
|
cuda: ${{ matrix.cuda }}
|
||||||
run: |
|
run: |
|
||||||
python3 setup.py sdist
|
./scripts/github_actions/install_cudnn.sh
|
||||||
ls -l dist/*
|
|
||||||
|
- name: Build pip packages
|
||||||
|
shell: bash
|
||||||
|
env:
|
||||||
|
KALDIFEAT_IS_FOR_PYPI: 1
|
||||||
|
run: |
|
||||||
|
tag=$(python3 -c "import sys; print(''.join(sys.version[:3].split('.')))")
|
||||||
|
export KALDIFEAT_MAKE_ARGS="-j2"
|
||||||
|
python3 setup.py bdist_wheel --python-tag=py${tag}
|
||||||
|
ls -lh dist/
|
||||||
|
|
||||||
- name: Publish wheels to PyPI
|
- name: Publish wheels to PyPI
|
||||||
env:
|
env:
|
||||||
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
|
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
|
||||||
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
|
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
|
||||||
run: |
|
run: |
|
||||||
twine upload dist/kaldifeat-*.tar.gz
|
twine upload dist/kaldifeat-*.whl
|
||||||
|
|
||||||
|
- name: Upload Wheel
|
||||||
|
uses: actions/upload-artifact@v2
|
||||||
|
with:
|
||||||
|
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-${{ matrix.os }}
|
||||||
|
path: dist/*.whl
|
||||||
|
38
README.md
38
README.md
@ -4,20 +4,36 @@ Wrap kaldi's feature computations to Python with PyTorch support.
|
|||||||
|
|
||||||
# Installation
|
# Installation
|
||||||
|
|
||||||
`kaldifeat` can be installed by
|
## From PyPi with pip
|
||||||
|
|
||||||
|
If you install `kaldifeat` using `pip`, it will also install
|
||||||
|
PyTorch 1.8.1. If this is not what you want, please install `kaldifeat`
|
||||||
|
from source (see below).
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install kaldifeat
|
pip install kaldifeat
|
||||||
```
|
```
|
||||||
|
|
||||||
# TODOs
|
## From source
|
||||||
|
|
||||||
- [ ] Add Python interface
|
The following are the commands to compile `kaldifeat` from source.
|
||||||
- [ ] Support torch.device so that it can switch between CUDA and CPU
|
We assume that you have installed `cmake` and PyTorch.
|
||||||
- [ ] Add unit tests
|
cmake 3.11 is known to work. Other cmake versions may also work.
|
||||||
- [ ] Set up GitHub actions
|
PyTorch 1.8.1 is known to work. Other PyTorch versions may also work.
|
||||||
- [ ] Benchmark its speed and compare it with Kaldi
|
|
||||||
- [ ] Support batch processing of multiple waves
|
```bash
|
||||||
- [ ] Handle non-default parameters
|
mkdir /some/path
|
||||||
- [ ] Support MFCC and other features available in Kaldi
|
git clone https://github.com/csukuangfj/kaldifeat.git
|
||||||
- [ ] Publish it to PyPI
|
cd kaldifeat
|
||||||
|
python setup.py install
|
||||||
|
```
|
||||||
|
|
||||||
|
To test whether `kaldifeat` was installed successfully, you can run:
|
||||||
|
```
|
||||||
|
python3 -c "import kaldifeat; print(kaldifeat.__version__)"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
Please refer to <https://kaldifeat.readthedocs.io/en/latest/usage.html>
|
||||||
|
for how to use `kaldifeat`.
|
||||||
|
@ -9,6 +9,12 @@ from pathlib import Path
|
|||||||
import setuptools
|
import setuptools
|
||||||
from setuptools.command.build_ext import build_ext
|
from setuptools.command.build_ext import build_ext
|
||||||
|
|
||||||
|
|
||||||
|
def is_for_pypi():
|
||||||
|
ans = os.environ.get("KALDIFEAT_IS_FOR_PYPI", None)
|
||||||
|
return ans is not None
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
|
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
|
||||||
|
|
||||||
@ -17,11 +23,12 @@ try:
|
|||||||
_bdist_wheel.finalize_options(self)
|
_bdist_wheel.finalize_options(self)
|
||||||
# In this case, the generated wheel has a name in the form
|
# In this case, the generated wheel has a name in the form
|
||||||
# k2-xxx-pyxx-none-any.whl
|
# k2-xxx-pyxx-none-any.whl
|
||||||
# self.root_is_pure = True
|
if is_for_pypi():
|
||||||
|
self.root_is_pure = True
|
||||||
# The generated wheel has a name ending with
|
else:
|
||||||
# -linux_x86_64.whl
|
# The generated wheel has a name ending with
|
||||||
self.root_is_pure = False
|
# -linux_x86_64.whl
|
||||||
|
self.root_is_pure = False
|
||||||
|
|
||||||
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
20
doc/Makefile
Normal file
20
doc/Makefile
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
# Minimal makefile for Sphinx documentation
|
||||||
|
#
|
||||||
|
|
||||||
|
# You can set these variables from the command line, and also
|
||||||
|
# from the environment for the first two.
|
||||||
|
SPHINXOPTS ?=
|
||||||
|
SPHINXBUILD ?= sphinx-build
|
||||||
|
SOURCEDIR = source
|
||||||
|
BUILDDIR = build
|
||||||
|
|
||||||
|
# Put it first so that "make" without argument is like "make help".
|
||||||
|
help:
|
||||||
|
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
||||||
|
|
||||||
|
.PHONY: help Makefile
|
||||||
|
|
||||||
|
# Catch-all target: route all unknown targets to Sphinx using the new
|
||||||
|
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
|
||||||
|
%: Makefile
|
||||||
|
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
35
doc/make.bat
Normal file
35
doc/make.bat
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
@ECHO OFF
|
||||||
|
|
||||||
|
pushd %~dp0
|
||||||
|
|
||||||
|
REM Command file for Sphinx documentation
|
||||||
|
|
||||||
|
if "%SPHINXBUILD%" == "" (
|
||||||
|
set SPHINXBUILD=sphinx-build
|
||||||
|
)
|
||||||
|
set SOURCEDIR=source
|
||||||
|
set BUILDDIR=build
|
||||||
|
|
||||||
|
if "%1" == "" goto help
|
||||||
|
|
||||||
|
%SPHINXBUILD% >NUL 2>NUL
|
||||||
|
if errorlevel 9009 (
|
||||||
|
echo.
|
||||||
|
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
|
||||||
|
echo.installed, then set the SPHINXBUILD environment variable to point
|
||||||
|
echo.to the full path of the 'sphinx-build' executable. Alternatively you
|
||||||
|
echo.may add the Sphinx directory to PATH.
|
||||||
|
echo.
|
||||||
|
echo.If you don't have Sphinx installed, grab it from
|
||||||
|
echo.http://sphinx-doc.org/
|
||||||
|
exit /b 1
|
||||||
|
)
|
||||||
|
|
||||||
|
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
||||||
|
goto end
|
||||||
|
|
||||||
|
:help
|
||||||
|
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
||||||
|
|
||||||
|
:end
|
||||||
|
popd
|
6
doc/requirements.txt
Normal file
6
doc/requirements.txt
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
dataclasses
|
||||||
|
recommonmark
|
||||||
|
sphinx
|
||||||
|
sphinx-autodoc-typehints
|
||||||
|
sphinx_rtd_theme
|
||||||
|
sphinxcontrib-bibtex
|
72
doc/source/code/test_fbank.py
Executable file
72
doc/source/code/test_fbank.py
Executable file
@ -0,0 +1,72 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
# Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import soundfile as sf
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import kaldifeat
|
||||||
|
|
||||||
|
|
||||||
|
def read_wave(filename) -> torch.Tensor:
|
||||||
|
"""Read a wave file and return it as a 1-D tensor.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
You don't need to scale it to [-32768, 32767].
|
||||||
|
We use scaling here to follow the approach in Kaldi.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename:
|
||||||
|
Filename of a sound file.
|
||||||
|
Returns:
|
||||||
|
Return a 1-D tensor containing audio samples.
|
||||||
|
"""
|
||||||
|
with sf.SoundFile(filename) as sf_desc:
|
||||||
|
sampling_rate = sf_desc.samplerate
|
||||||
|
assert sampling_rate == 16000
|
||||||
|
data = sf_desc.read(dtype=np.float32, always_2d=False)
|
||||||
|
data *= 32768
|
||||||
|
return torch.from_numpy(data)
|
||||||
|
|
||||||
|
|
||||||
|
def test_fbank():
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", 0)
|
||||||
|
|
||||||
|
wave0 = read_wave("test_data/test.wav")
|
||||||
|
wave1 = read_wave("test_data/test2.wav")
|
||||||
|
|
||||||
|
wave0 = wave0.to(device)
|
||||||
|
wave1 = wave1.to(device)
|
||||||
|
|
||||||
|
opts = kaldifeat.FbankOptions()
|
||||||
|
opts.frame_opts.dither = 0
|
||||||
|
opts.device = device
|
||||||
|
|
||||||
|
fbank = kaldifeat.Fbank(opts)
|
||||||
|
|
||||||
|
# We can compute fbank features in batches
|
||||||
|
features = fbank([wave0, wave1])
|
||||||
|
assert isinstance(features, list), f"{type(features)}"
|
||||||
|
assert len(features) == 2
|
||||||
|
|
||||||
|
# We can also compute fbank features for a single wave
|
||||||
|
features0 = fbank(wave0)
|
||||||
|
features1 = fbank(wave1)
|
||||||
|
|
||||||
|
assert torch.allclose(features[0], features0)
|
||||||
|
assert torch.allclose(features[1], features1)
|
||||||
|
|
||||||
|
# To compute fbank features for only a specified frame
|
||||||
|
audio_frames = fbank.convert_samples_to_frames(wave0)
|
||||||
|
feature_frame_1 = fbank.compute(audio_frames[1])
|
||||||
|
feature_frame_10 = fbank.compute(audio_frames[10])
|
||||||
|
|
||||||
|
assert torch.allclose(features0[1], feature_frame_1)
|
||||||
|
assert torch.allclose(features0[10], feature_frame_10)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_fbank()
|
104
doc/source/conf.py
Normal file
104
doc/source/conf.py
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
# Configuration file for the Sphinx documentation builder.
|
||||||
|
#
|
||||||
|
# This file only contains a selection of the most common options. For a full
|
||||||
|
# list see the documentation:
|
||||||
|
# https://www.sphinx-doc.org/en/master/usage/configuration.html
|
||||||
|
|
||||||
|
# -- Path setup --------------------------------------------------------------
|
||||||
|
|
||||||
|
# If extensions (or modules to document with autodoc) are in another directory,
|
||||||
|
# add these directories to sys.path here. If the directory is relative to the
|
||||||
|
# documentation root, use os.path.abspath to make it absolute, like shown here.
|
||||||
|
#
|
||||||
|
# import os
|
||||||
|
import re
|
||||||
|
|
||||||
|
import sphinx_rtd_theme
|
||||||
|
|
||||||
|
# import sys
|
||||||
|
# sys.path.insert(0, os.path.abspath('.'))
|
||||||
|
|
||||||
|
|
||||||
|
# -- Project information -----------------------------------------------------
|
||||||
|
|
||||||
|
project = "kaldifeat"
|
||||||
|
copyright = "2021, Fangjun Kuang"
|
||||||
|
author = "Fangjun Kuang"
|
||||||
|
|
||||||
|
|
||||||
|
def get_version():
|
||||||
|
cmake_file = "../../CMakeLists.txt"
|
||||||
|
with open(cmake_file) as f:
|
||||||
|
content = f.read()
|
||||||
|
|
||||||
|
version = re.search(r"set\(kaldifeat_VERSION (.*)\)", content).group(1)
|
||||||
|
return version.strip('"')
|
||||||
|
|
||||||
|
|
||||||
|
version = get_version()
|
||||||
|
release = version
|
||||||
|
|
||||||
|
|
||||||
|
# -- General configuration ---------------------------------------------------
|
||||||
|
|
||||||
|
# Add any Sphinx extension module names here, as strings. They can be
|
||||||
|
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
|
||||||
|
# ones.
|
||||||
|
extensions = [
|
||||||
|
"recommonmark",
|
||||||
|
"sphinx.ext.autodoc",
|
||||||
|
"sphinx.ext.githubpages",
|
||||||
|
"sphinx.ext.napoleon",
|
||||||
|
"sphinx_autodoc_typehints",
|
||||||
|
"sphinx_rtd_theme",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add any paths that contain templates here, relative to this directory.
|
||||||
|
templates_path = ["_templates"]
|
||||||
|
|
||||||
|
# List of patterns, relative to source directory, that match files and
|
||||||
|
# directories to ignore when looking for source files.
|
||||||
|
# This pattern also affects html_static_path and html_extra_path.
|
||||||
|
exclude_patterns = []
|
||||||
|
|
||||||
|
source_suffix = {
|
||||||
|
".rst": "restructuredtext",
|
||||||
|
".md": "markdown",
|
||||||
|
}
|
||||||
|
master_doc = "index"
|
||||||
|
|
||||||
|
|
||||||
|
# -- Options for HTML output -------------------------------------------------
|
||||||
|
|
||||||
|
# The theme to use for HTML and HTML Help pages. See the documentation for
|
||||||
|
# a list of builtin themes.
|
||||||
|
#
|
||||||
|
html_theme = "sphinx_rtd_theme"
|
||||||
|
html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
|
||||||
|
html_show_sourcelink = True
|
||||||
|
|
||||||
|
# Add any paths that contain custom static files (such as style sheets) here,
|
||||||
|
# relative to this directory. They are copied after the builtin static files,
|
||||||
|
# so a file named "default.css" will overwrite the builtin "default.css".
|
||||||
|
html_static_path = ["_static"]
|
||||||
|
|
||||||
|
pygments_style = "sphinx"
|
||||||
|
|
||||||
|
numfig = True
|
||||||
|
|
||||||
|
html_context = {
|
||||||
|
"display_github": True,
|
||||||
|
"github_user": "csukuangfj",
|
||||||
|
"github_repo": "kaldifeat",
|
||||||
|
"github_version": "master",
|
||||||
|
"conf_py_path": "/kaldifeat/docs/source/",
|
||||||
|
}
|
||||||
|
|
||||||
|
# refer to
|
||||||
|
# https://sphinx-rtd-theme.readthedocs.io/en/latest/configuring.html
|
||||||
|
html_theme_options = {
|
||||||
|
"logo_only": False,
|
||||||
|
"display_version": True,
|
||||||
|
"prev_next_buttons_location": "bottom",
|
||||||
|
"style_external_links": True,
|
||||||
|
}
|
24
doc/source/index.rst
Normal file
24
doc/source/index.rst
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
.. kaldifeat documentation master file, created by
|
||||||
|
sphinx-quickstart on Fri Jul 16 20:15:27 2021.
|
||||||
|
You can adapt this file completely to your liking, but it should at least
|
||||||
|
contain the root `toctree` directive.
|
||||||
|
|
||||||
|
kaldifeat
|
||||||
|
=========
|
||||||
|
|
||||||
|
`kaldifeat <https://github.com/csukuangfj/kaldifeat>`_ implements
|
||||||
|
feature extraction algorithms **compatible** with kaldi using PyTorch, supporting CUDA
|
||||||
|
as well as autograd.
|
||||||
|
|
||||||
|
Currently, only fbank features are supported.
|
||||||
|
It can produce the same feature output as ``compute-fbank-feats`` (from kaldi)
|
||||||
|
when given the same options.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
.. toctree::
|
||||||
|
:maxdepth: 2
|
||||||
|
:caption: Contents:
|
||||||
|
|
||||||
|
installation
|
||||||
|
usage
|
54
doc/source/installation.rst
Normal file
54
doc/source/installation.rst
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
Installation
|
||||||
|
============
|
||||||
|
|
||||||
|
.. _from source:
|
||||||
|
|
||||||
|
Install kaldifeat from source
|
||||||
|
-----------------------------
|
||||||
|
|
||||||
|
You have to install ``cmake`` and ``PyTorch`` first.
|
||||||
|
|
||||||
|
- ``cmake`` 3.11 is known to work. Other CMake versions may also work.
|
||||||
|
- ``PyTorch`` 1.8.1 is known to work. Other PyTorch versions may also work.
|
||||||
|
- Python >= 3.6
|
||||||
|
|
||||||
|
|
||||||
|
The commands to install ``kaldifeat`` from source are:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
git clone https://github.com/csukuangfj/kaldifeat
|
||||||
|
cd kaldifeat
|
||||||
|
python3 setup.py install
|
||||||
|
|
||||||
|
To test that you have installed ``kaldifeat`` successfully, please run:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
python3 -c "import kaldifeat; print(kaldifeat.__version__)"
|
||||||
|
|
||||||
|
It should print the version, e.g., ``1.0``.
|
||||||
|
|
||||||
|
Install kaldifeat from PyPI
|
||||||
|
---------------------------
|
||||||
|
|
||||||
|
The pre-built ``kaldifeat`` hosted on PyPI uses PyTorch 1.8.1.
|
||||||
|
If you install ``kaldifeat`` using pip, it will replace your locally
|
||||||
|
installed PyTorch automatically with PyTorch 1.8.1.
|
||||||
|
|
||||||
|
If you don't want this happen, please `Install kaldifeat from source`_.
|
||||||
|
|
||||||
|
The command to install ``kaldifeat`` from PyPI is:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
pip install kaldifeat
|
||||||
|
|
||||||
|
|
||||||
|
To test that you have installed ``kaldifeat`` successfully, please run:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
python3 -c "import kaldifeat; print(kaldifeat.__version__)"
|
||||||
|
|
||||||
|
It should print the version, e.g., ``1.0``.
|
212
doc/source/usage.rst
Normal file
212
doc/source/usage.rst
Normal file
@ -0,0 +1,212 @@
|
|||||||
|
Usage
|
||||||
|
=====
|
||||||
|
|
||||||
|
Let us first see the help message of kaldi's ``compute-fbank-feats``:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
$ compute-fbank-feats
|
||||||
|
|
||||||
|
Create Mel-filter bank (FBANK) feature files.
|
||||||
|
Usage: compute-fbank-feats [options...] <wav-rspecifier> <feats-wspecifier>
|
||||||
|
|
||||||
|
Options:
|
||||||
|
--allow-downsample : If true, allow the input waveform to have a higher frequency than the specified --sample-frequency (and we'll downsample). (bool, default = false)
|
||||||
|
--allow-upsample : If true, allow the input waveform to have a lower frequency than the specified --sample-frequency (and we'll upsample). (bool, default = false)
|
||||||
|
--blackman-coeff : Constant coefficient for generalized Blackman window. (float, default = 0.42)
|
||||||
|
--channel : Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (int, default = -1)
|
||||||
|
--debug-mel : Print out debugging information for mel bin computation (bool, default = false)
|
||||||
|
--dither : Dithering constant (0.0 means no dither). If you turn this off, you should set the --energy-floor option, e.g. to 1.0 or 0.1 (float, default = 1)
|
||||||
|
--energy-floor : Floor on energy (absolute, not relative) in FBANK computation. Only makes a difference if --use-energy=true; only necessary if --dither=0.0. Suggested values: 0.1 or 1.0 (float, default = 0)
|
||||||
|
--frame-length : Frame length in milliseconds (float, default = 25)
|
||||||
|
--frame-shift : Frame shift in milliseconds (float, default = 10)
|
||||||
|
--high-freq : High cutoff frequency for mel bins (if <= 0, offset from Nyquist) (float, default = 0)
|
||||||
|
--htk-compat : If true, put energy last. Warning: not sufficient to get HTK compatible features (need to change other parameters). (bool, default = false)
|
||||||
|
--low-freq : Low cutoff frequency for mel bins (float, default = 20)
|
||||||
|
--max-feature-vectors : Memory optimization. If larger than 0, periodically remove feature vectors so that only this number of the latest feature vectors is retained. (int, default = -1)
|
||||||
|
--min-duration : Minimum duration of segments to process (in seconds). (float, default = 0)
|
||||||
|
--num-mel-bins : Number of triangular mel-frequency bins (int, default = 23)
|
||||||
|
--output-format : Format of the output files [kaldi, htk] (string, default = "kaldi")
|
||||||
|
--preemphasis-coefficient : Coefficient for use in signal preemphasis (float, default = 0.97)
|
||||||
|
--raw-energy : If true, compute energy before preemphasis and windowing (bool, default = true)
|
||||||
|
--remove-dc-offset : Subtract mean from waveform on each frame (bool, default = true)
|
||||||
|
--round-to-power-of-two : If true, round window size to power of two by zero-padding input to FFT. (bool, default = true)
|
||||||
|
--sample-frequency : Waveform data sample frequency (must match the waveform file, if specified there) (float, default = 16000)
|
||||||
|
--snip-edges : If true, end effects will be handled by outputting only frames that completely fit in the file, and the number of frames depends on the frame-length. If false, the number of frames depends only on the frame-shift, and we reflect the data at the ends. (bool, default = true)
|
||||||
|
--subtract-mean : Subtract mean of each feature file [CMS]; not recommended to do it this way. (bool, default = false)
|
||||||
|
--use-energy : Add an extra dimension with energy to the FBANK output. (bool, default = false)
|
||||||
|
--use-log-fbank : If true, produce log-filterbank, else produce linear. (bool, default = true)
|
||||||
|
--use-power : If true, use power, else use magnitude. (bool, default = true)
|
||||||
|
--utt2spk : Utterance to speaker-id map (if doing VTLN and you have warps per speaker) (string, default = "")
|
||||||
|
--vtln-high : High inflection point in piecewise linear VTLN warping function (if negative, offset from high-mel-freq (float, default = -500)
|
||||||
|
--vtln-low : Low inflection point in piecewise linear VTLN warping function (float, default = 100)
|
||||||
|
--vtln-map : Map from utterance or speaker-id to vtln warp factor (rspecifier) (string, default = "")
|
||||||
|
--vtln-warp : Vtln warp factor (only applicable if vtln-map not specified) (float, default = 1)
|
||||||
|
--window-type : Type of window ("hamming"|"hanning"|"povey"|"rectangular"|"sine"|"blackmann") (string, default = "povey")
|
||||||
|
--write-utt2dur : Wspecifier to write duration of each utterance in seconds, e.g. 'ark,t:utt2dur'. (string, default = "")
|
||||||
|
|
||||||
|
Standard options:
|
||||||
|
--config : Configuration file to read (this option may be repeated) (string, default = "")
|
||||||
|
--help : Print out usage message (bool, default = false)
|
||||||
|
--print-args : Print the command line arguments (to stderr) (bool, default = true)
|
||||||
|
--verbose : Verbose level (higher->more logging) (int, default = 0)
|
||||||
|
|
||||||
|
FbankOptions
|
||||||
|
------------
|
||||||
|
|
||||||
|
``kaldifeat`` reuses the same options from kaldi's ``compute-fbank-feats``.
|
||||||
|
|
||||||
|
The following shows the default values of ``kaldifeat.FbankOptions``:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
>>> import kaldifeat
|
||||||
|
>>> fbank_opts = kaldifeat.FbankOptions()
|
||||||
|
>>> print(fbank_opts)
|
||||||
|
frame_opts:
|
||||||
|
samp_freq: 16000
|
||||||
|
frame_shift_ms: 10
|
||||||
|
frame_length_ms: 25
|
||||||
|
dither: 1
|
||||||
|
preemph_coeff: 0.97
|
||||||
|
remove_dc_offset: 1
|
||||||
|
window_type: povey
|
||||||
|
round_to_power_of_two: 1
|
||||||
|
blackman_coeff: 0.42
|
||||||
|
snip_edges: 1
|
||||||
|
|
||||||
|
|
||||||
|
mel_opts:
|
||||||
|
num_bins: 23
|
||||||
|
low_freq: 20
|
||||||
|
high_freq: 0
|
||||||
|
vtln_low: 100
|
||||||
|
vtln_high: -500
|
||||||
|
debug_mel: 0
|
||||||
|
htk_mode: 0
|
||||||
|
|
||||||
|
use_energy: 0
|
||||||
|
energy_floor: 0
|
||||||
|
raw_energy: 1
|
||||||
|
htk_compat: 0
|
||||||
|
use_log_fbank: 1
|
||||||
|
use_power: 1
|
||||||
|
device: cpu
|
||||||
|
|
||||||
|
It consists of three parts:
|
||||||
|
|
||||||
|
- ``frame_opts``
|
||||||
|
|
||||||
|
Options in this part are accessed by ``frame_opts.xxx``. That is, to access
|
||||||
|
the sample rate, you use:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
>>> fbank_opts = kaldifeat.FbankOptions()
|
||||||
|
>>> print(fbank_opts.frame_opts.samp_freq)
|
||||||
|
16000.0
|
||||||
|
|
||||||
|
- ``mel_opts``
|
||||||
|
|
||||||
|
Options in this part are accessed by ``mel_opts.xxx``. That is, to access
|
||||||
|
the number of mel bins, you use:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
>>> fbank_opts = kaldifeat.FbankOptions()
|
||||||
|
>>> print(fbank_opts.mel_opts.num_bins)
|
||||||
|
23
|
||||||
|
|
||||||
|
- fbank related
|
||||||
|
|
||||||
|
Options in this part are accessed directly. That is, to access the device
|
||||||
|
field, you use:
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
>>> print(fbank_opts.device)
|
||||||
|
cpu
|
||||||
|
>>> fbank_opts.device = 'cuda:0'
|
||||||
|
>>> print(fbank_opts.device)
|
||||||
|
cuda:0
|
||||||
|
>>> import torch
|
||||||
|
>>> fbank_opts.device = torch.device('cuda', 0)
|
||||||
|
>>> print(fbank_opts.device)
|
||||||
|
cuda:0
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
To change the sample rate to 8000, you can use:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
>>> fbank_opts = kaldifeat.FbankOptions()
|
||||||
|
>>> print(fbank_opts.frame_opts.samp_freq)
|
||||||
|
16000.0
|
||||||
|
>>> fbank_opts.frame_opts.samp_freq = 8000
|
||||||
|
>>> print(fbank_opts.frame_opts.samp_freq)
|
||||||
|
8000.0
|
||||||
|
|
||||||
|
To change ``snip_edges`` to ``False``, you can use:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
>>> fbank_opts.frame_opts.snip_edges = False
|
||||||
|
>>> print(fbank_opts.frame_opts.snip_edges)
|
||||||
|
False
|
||||||
|
|
||||||
|
To change number of mel bins to 80, you can use:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
>>> print(fbank_opts.mel_opts.num_bins)
|
||||||
|
23
|
||||||
|
>>> fbank_opts.mel_opts.num_bins = 80
|
||||||
|
>>> print(fbank_opts.mel_opts.num_bins)
|
||||||
|
80
|
||||||
|
|
||||||
|
To change the device to ``cuda``, you can use:
|
||||||
|
|
||||||
|
|
||||||
|
Fbank
|
||||||
|
-----
|
||||||
|
|
||||||
|
The following shows how to use ``kaldifeat.Fbank`` to compute
|
||||||
|
the fbank features of sound files.
|
||||||
|
|
||||||
|
First, let us generate two sound files using ``sox``:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
# generate a wav of two seconds, containing a sine-wave
|
||||||
|
# swept from 300 Hz to 3300 Hz
|
||||||
|
sox -n -r 16000 -b 16 test.wav synth 1.2 sine 300-3300
|
||||||
|
|
||||||
|
# another sound file with 0.5 seconds
|
||||||
|
sox -n -r 16000 -b 16 test2.wav synth 0.5 sine 300-3300
|
||||||
|
|
||||||
|
.. hint::
|
||||||
|
|
||||||
|
You can find the above two files by visiting the following two links:
|
||||||
|
|
||||||
|
- `test.wav <https://github.com/csukuangfj/kaldifeat/blob/master/kaldifeat/python/tests/test_data/test.wav>`_
|
||||||
|
- `test2.wav <https://github.com/csukuangfj/kaldifeat/blob/master/kaldifeat/python/tests/test_data/test2.wav>`_
|
||||||
|
|
||||||
|
The `following code <https://github.com/csukuangfj/kaldifeat/blob/master/kaldifeat/python/tests/test_fbank.py>`_
|
||||||
|
shows the usage of ``kaldifeat.Fbank``.
|
||||||
|
|
||||||
|
It shows:
|
||||||
|
|
||||||
|
- How to read a sound file. Note that audio samples are scaled to the range [-32768, 32768].
|
||||||
|
The intention is to produce the same output as kaldi. You don't need to scale it if
|
||||||
|
you don't care about the compatibility with kaldi
|
||||||
|
|
||||||
|
- ``kaldifeat.Fbank`` supports CUDA as well as CPU
|
||||||
|
|
||||||
|
- ``kaldifeat.Fbank`` supports processing sound file in a batch as well as accepting
|
||||||
|
a single sound file
|
||||||
|
|
||||||
|
|
||||||
|
.. literalinclude:: ./code/test_fbank.py
|
||||||
|
:caption: Demo of ``kaldifeat.Fbank``
|
||||||
|
:language: python
|
@ -1,3 +1,4 @@
|
|||||||
|
import torch
|
||||||
from _kaldifeat import FbankOptions, FrameExtractionOptions, MelBanksOptions
|
from _kaldifeat import FbankOptions, FrameExtractionOptions, MelBanksOptions
|
||||||
|
|
||||||
from .fbank import Fbank
|
from .fbank import Fbank
|
||||||
|
59
scripts/github_actions/install_cuda.sh
Executable file
59
scripts/github_actions/install_cuda.sh
Executable file
@ -0,0 +1,59 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
#
|
||||||
|
# Copyright 2020 Mobvoi Inc. (authors: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
echo "cuda version: $cuda"
|
||||||
|
|
||||||
|
case "$cuda" in
|
||||||
|
10.0)
|
||||||
|
url=https://developer.nvidia.com/compute/cuda/10.0/Prod/local_installers/cuda_10.0.130_410.48_linux
|
||||||
|
;;
|
||||||
|
10.1)
|
||||||
|
# WARNING: there are bugs in
|
||||||
|
# https://developer.nvidia.com/compute/cuda/10.1/Prod/local_installers/cuda_10.1.105_418.39_linux.run
|
||||||
|
# with GCC 7. Please use the following version
|
||||||
|
url=http://developer.download.nvidia.com/compute/cuda/10.1/Prod/local_installers/cuda_10.1.243_418.87.00_linux.run
|
||||||
|
;;
|
||||||
|
10.2)
|
||||||
|
url=http://developer.download.nvidia.com/compute/cuda/10.2/Prod/local_installers/cuda_10.2.89_440.33.01_linux.run
|
||||||
|
;;
|
||||||
|
11.0)
|
||||||
|
url=http://developer.download.nvidia.com/compute/cuda/11.0.2/local_installers/cuda_11.0.2_450.51.05_linux.run
|
||||||
|
;;
|
||||||
|
11.1)
|
||||||
|
# url=https://developer.download.nvidia.com/compute/cuda/11.1.0/local_installers/cuda_11.1.0_455.23.05_linux.run
|
||||||
|
url=https://developer.download.nvidia.com/compute/cuda/11.1.1/local_installers/cuda_11.1.1_455.32.00_linux.run
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
echo "Unknown cuda version: $cuda"
|
||||||
|
exit 1
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
|
||||||
|
function retry() {
|
||||||
|
$* || (sleep 1 && $*) || (sleep 2 && $*) || (sleep 4 && $*) || (sleep 8 && $*)
|
||||||
|
}
|
||||||
|
|
||||||
|
retry curl -LSs -O $url
|
||||||
|
filename=$(basename $url)
|
||||||
|
echo "filename: $filename"
|
||||||
|
chmod +x ./$filename
|
||||||
|
sudo ./$filename --toolkit --silent
|
||||||
|
rm -fv ./$filename
|
||||||
|
|
||||||
|
export CUDA_HOME=/usr/local/cuda
|
||||||
|
export PATH=$CUDA_HOME/bin:$PATH
|
||||||
|
export LD_LIBRARY_PATH=$CUDA_HOME/lib:$LD_LIBRARY_PATH
|
||||||
|
export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH
|
58
scripts/github_actions/install_cudnn.sh
Executable file
58
scripts/github_actions/install_cudnn.sh
Executable file
@ -0,0 +1,58 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
#
|
||||||
|
# Copyright 2020 Mobvoi Inc. (authors: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
case $cuda in
|
||||||
|
10.0)
|
||||||
|
filename=cudnn-10.0-linux-x64-v7.6.5.32.tgz
|
||||||
|
url=http://www.mediafire.com/file/1037lb1vmj9qdtq/cudnn-10.0-linux-x64-v7.6.5.32.tgz/file
|
||||||
|
;;
|
||||||
|
10.1)
|
||||||
|
filename=cudnn-10.1-linux-x64-v8.0.2.39.tgz
|
||||||
|
url=http://www.mediafire.com/file/fnl2wg0h757qhd7/cudnn-10.1-linux-x64-v8.0.2.39.tgz/file
|
||||||
|
;;
|
||||||
|
10.2)
|
||||||
|
filename=cudnn-10.2-linux-x64-v8.0.2.39.tgz
|
||||||
|
url=http://www.mediafire.com/file/sc2nvbtyg0f7ien/cudnn-10.2-linux-x64-v8.0.2.39.tgz/file
|
||||||
|
;;
|
||||||
|
11.0)
|
||||||
|
filename=cudnn-11.0-linux-x64-v8.0.5.39.tgz
|
||||||
|
url=https://www.mediafire.com/file/abyhnls106ko9kp/cudnn-11.0-linux-x64-v8.0.5.39.tgz/file
|
||||||
|
;;
|
||||||
|
11.1)
|
||||||
|
filename=cudnn-11.1-linux-x64-v8.0.5.39.tgz
|
||||||
|
url=https://www.mediafire.com/file/qx55zd65773xonv/cudnn-11.1-linux-x64-v8.0.5.39.tgz/file
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
echo "Unsupported cuda version: $cuda"
|
||||||
|
exit 1
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
|
||||||
|
function retry() {
|
||||||
|
$* || (sleep 1 && $*) || (sleep 2 && $*) || (sleep 4 && $*) || (sleep 8 && $*)
|
||||||
|
}
|
||||||
|
|
||||||
|
# It is forked from https://github.com/Juvenal-Yescas/mediafire-dl
|
||||||
|
# https://github.com/Juvenal-Yescas/mediafire-dl/pull/2 changes the filename and breaks the CI.
|
||||||
|
# We use a separate fork to keep the link fixed.
|
||||||
|
retry wget https://raw.githubusercontent.com/csukuangfj/mediafire-dl/master/mediafire_dl.py
|
||||||
|
|
||||||
|
sed -i 's/quiet=False/quiet=True/' mediafire_dl.py
|
||||||
|
retry python3 mediafire_dl.py "$url"
|
||||||
|
sudo tar xf ./$filename -C /usr/local
|
||||||
|
rm -v ./$filename
|
||||||
|
|
||||||
|
sudo sed -i '59i#define CUDNN_MAJOR 8' /usr/local/cuda/include/cudnn.h
|
108
scripts/github_actions/install_torch.sh
Executable file
108
scripts/github_actions/install_torch.sh
Executable file
@ -0,0 +1,108 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
#
|
||||||
|
# Copyright 2020 Mobvoi Inc. (authors: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
case ${torch} in
|
||||||
|
1.5.*)
|
||||||
|
case ${cuda} in
|
||||||
|
10.1)
|
||||||
|
package="torch==${torch}+cu101"
|
||||||
|
url=https://download.pytorch.org/whl/torch_stable.html
|
||||||
|
;;
|
||||||
|
10.2)
|
||||||
|
package="torch==${torch}"
|
||||||
|
# Leave url empty to use PyPI.
|
||||||
|
# torch_stable provides cu92 but we want cu102
|
||||||
|
url=
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
;;
|
||||||
|
1.6.0)
|
||||||
|
case ${cuda} in
|
||||||
|
10.1)
|
||||||
|
package="torch==1.6.0+cu101"
|
||||||
|
url=https://download.pytorch.org/whl/torch_stable.html
|
||||||
|
;;
|
||||||
|
10.2)
|
||||||
|
package="torch==1.6.0"
|
||||||
|
# Leave it empty to use PyPI.
|
||||||
|
# torch_stable provides cu92 but we want cu102
|
||||||
|
url=
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
;;
|
||||||
|
1.7.*)
|
||||||
|
case ${cuda} in
|
||||||
|
10.1)
|
||||||
|
package="torch==${torch}+cu101"
|
||||||
|
url=https://download.pytorch.org/whl/torch_stable.html
|
||||||
|
;;
|
||||||
|
10.2)
|
||||||
|
package="torch==${torch}"
|
||||||
|
# Leave it empty to use PyPI.
|
||||||
|
# torch_stable provides cu92 but we want cu102
|
||||||
|
url=
|
||||||
|
;;
|
||||||
|
11.0)
|
||||||
|
package="torch==${torch}+cu110"
|
||||||
|
url=https://download.pytorch.org/whl/torch_stable.html
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
;;
|
||||||
|
1.8.*)
|
||||||
|
case ${cuda} in
|
||||||
|
10.1)
|
||||||
|
package="torch==${torch}+cu101"
|
||||||
|
url=https://download.pytorch.org/whl/torch_stable.html
|
||||||
|
;;
|
||||||
|
10.2)
|
||||||
|
package="torch==${torch}"
|
||||||
|
# Leave it empty to use PyPI.
|
||||||
|
url=
|
||||||
|
;;
|
||||||
|
11.1)
|
||||||
|
package="torch==${torch}+cu111"
|
||||||
|
url=https://download.pytorch.org/whl/torch_stable.html
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
;;
|
||||||
|
1.9.0)
|
||||||
|
case ${cuda} in
|
||||||
|
10.2)
|
||||||
|
package="torch==${torch}"
|
||||||
|
# Leave it empty to use PyPI.
|
||||||
|
url=
|
||||||
|
;;
|
||||||
|
11.1)
|
||||||
|
package="torch==${torch}+cu111"
|
||||||
|
url=https://download.pytorch.org/whl/torch_stable.html
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
echo "Unsupported PyTorch version: ${torch}"
|
||||||
|
exit 1
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
|
||||||
|
function retry() {
|
||||||
|
$* || (sleep 1 && $*) || (sleep 2 && $*) || (sleep 4 && $*) || (sleep 8 && $*)
|
||||||
|
}
|
||||||
|
|
||||||
|
if [ x"${url}" == "x" ]; then
|
||||||
|
retry python3 -m pip install -q $package
|
||||||
|
else
|
||||||
|
retry python3 -m pip install -q $package -f $url
|
||||||
|
fi
|
Loading…
x
Reference in New Issue
Block a user