Merge pull request #2 from csukuangfj/pip

Publish to PyPI.
This commit is contained in:
Fangjun Kuang 2021-07-16 21:25:28 +08:00 committed by GitHub
commit bac4db61c3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 872 additions and 25 deletions

View File

@ -1,3 +1,19 @@
# Copyright 2021 Xiaomi Corp. (author: Fangjun Kuang)
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
name: Publish to PyPI
on:
@ -7,30 +23,85 @@ on:
jobs:
pypi:
runs-on: ubuntu-latest
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-18.04]
cuda: ["10.1"]
gcc: ["5"]
torch: ["1.8.1"]
python-version: [3.6, 3.7, 3.8]
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Setup Python
uses: actions/setup-python@v2
with:
python-version: 3.6
python-version: ${{ matrix.python-version }}
- name: Install Python dependencies
- name: Install CUDA Toolkit ${{ matrix.cuda }}
shell: bash
env:
cuda: ${{ matrix.cuda }}
run: |
source ./scripts/github_actions/install_cuda.sh
echo "CUDA_HOME=${CUDA_HOME}" >> $GITHUB_ENV
echo "${CUDA_HOME}/bin" >> $GITHUB_PATH
echo "LD_LIBRARY_PATH=${CUDA_HOME}/lib:${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}" >> $GITHUB_ENV
- name: Display NVCC version
run: |
which nvcc
nvcc --version
- name: Install GCC ${{ matrix.gcc }}
run: |
sudo apt-get install -y gcc-${{ matrix.gcc }} g++-${{ matrix.gcc }}
echo "CC=/usr/bin/gcc-${{ matrix.gcc }}" >> $GITHUB_ENV
echo "CXX=/usr/bin/g++-${{ matrix.gcc }}" >> $GITHUB_ENV
echo "CUDAHOSTCXX=/usr/bin/g++-${{ matrix.gcc }}" >> $GITHUB_ENV
- name: Install PyTorch ${{ matrix.torch }}
env:
cuda: ${{ matrix.cuda }}
torch: ${{ matrix.torch }}
shell: bash
run: |
python3 -m pip install --upgrade pip
python3 -m pip install wheel twine setuptools
python3 -m pip install wheel twine typing_extensions
python3 -m pip install bs4 requests tqdm
- name: Build
shell: bash
./scripts/github_actions/install_torch.sh
python3 -c "import torch; print('torch version:', torch.__version__)"
- name: Download cudnn 8.0
env:
cuda: ${{ matrix.cuda }}
run: |
python3 setup.py sdist
ls -l dist/*
./scripts/github_actions/install_cudnn.sh
- name: Build pip packages
shell: bash
env:
KALDIFEAT_IS_FOR_PYPI: 1
run: |
tag=$(python3 -c "import sys; print(''.join(sys.version[:3].split('.')))")
export KALDIFEAT_MAKE_ARGS="-j2"
python3 setup.py bdist_wheel --python-tag=py${tag}
ls -lh dist/
- name: Publish wheels to PyPI
env:
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
run: |
twine upload dist/kaldifeat-*.tar.gz
twine upload dist/kaldifeat-*.whl
- name: Upload Wheel
uses: actions/upload-artifact@v2
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-${{ matrix.os }}
path: dist/*.whl

View File

@ -4,20 +4,36 @@ Wrap kaldi's feature computations to Python with PyTorch support.
# Installation
`kaldifeat` can be installed by
## From PyPi with pip
If you install `kaldifeat` using `pip`, it will also install
PyTorch 1.8.1. If this is not what you want, please install `kaldifeat`
from source (see below).
```bash
pip install kaldifeat
```
# TODOs
## From source
- [ ] Add Python interface
- [ ] Support torch.device so that it can switch between CUDA and CPU
- [ ] Add unit tests
- [ ] Set up GitHub actions
- [ ] Benchmark its speed and compare it with Kaldi
- [ ] Support batch processing of multiple waves
- [ ] Handle non-default parameters
- [ ] Support MFCC and other features available in Kaldi
- [ ] Publish it to PyPI
The following are the commands to compile `kaldifeat` from source.
We assume that you have installed `cmake` and PyTorch.
cmake 3.11 is known to work. Other cmake versions may also work.
PyTorch 1.8.1 is known to work. Other PyTorch versions may also work.
```bash
mkdir /some/path
git clone https://github.com/csukuangfj/kaldifeat.git
cd kaldifeat
python setup.py install
```
To test whether `kaldifeat` was installed successfully, you can run:
```
python3 -c "import kaldifeat; print(kaldifeat.__version__)"
```
## Usage
Please refer to <https://kaldifeat.readthedocs.io/en/latest/usage.html>
for how to use `kaldifeat`.

View File

@ -9,6 +9,12 @@ from pathlib import Path
import setuptools
from setuptools.command.build_ext import build_ext
def is_for_pypi():
ans = os.environ.get("KALDIFEAT_IS_FOR_PYPI", None)
return ans is not None
try:
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
@ -17,11 +23,12 @@ try:
_bdist_wheel.finalize_options(self)
# In this case, the generated wheel has a name in the form
# k2-xxx-pyxx-none-any.whl
# self.root_is_pure = True
# The generated wheel has a name ending with
# -linux_x86_64.whl
self.root_is_pure = False
if is_for_pypi():
self.root_is_pure = True
else:
# The generated wheel has a name ending with
# -linux_x86_64.whl
self.root_is_pure = False
except ImportError:

20
doc/Makefile Normal file
View File

@ -0,0 +1,20 @@
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = source
BUILDDIR = build
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

35
doc/make.bat Normal file
View File

@ -0,0 +1,35 @@
@ECHO OFF
pushd %~dp0
REM Command file for Sphinx documentation
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build
if "%1" == "" goto help
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.http://sphinx-doc.org/
exit /b 1
)
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end
:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
:end
popd

6
doc/requirements.txt Normal file
View File

@ -0,0 +1,6 @@
dataclasses
recommonmark
sphinx
sphinx-autodoc-typehints
sphinx_rtd_theme
sphinxcontrib-bibtex

72
doc/source/code/test_fbank.py Executable file
View File

@ -0,0 +1,72 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang)
import numpy as np
import soundfile as sf
import torch
import kaldifeat
def read_wave(filename) -> torch.Tensor:
"""Read a wave file and return it as a 1-D tensor.
Note:
You don't need to scale it to [-32768, 32767].
We use scaling here to follow the approach in Kaldi.
Args:
filename:
Filename of a sound file.
Returns:
Return a 1-D tensor containing audio samples.
"""
with sf.SoundFile(filename) as sf_desc:
sampling_rate = sf_desc.samplerate
assert sampling_rate == 16000
data = sf_desc.read(dtype=np.float32, always_2d=False)
data *= 32768
return torch.from_numpy(data)
def test_fbank():
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
wave0 = read_wave("test_data/test.wav")
wave1 = read_wave("test_data/test2.wav")
wave0 = wave0.to(device)
wave1 = wave1.to(device)
opts = kaldifeat.FbankOptions()
opts.frame_opts.dither = 0
opts.device = device
fbank = kaldifeat.Fbank(opts)
# We can compute fbank features in batches
features = fbank([wave0, wave1])
assert isinstance(features, list), f"{type(features)}"
assert len(features) == 2
# We can also compute fbank features for a single wave
features0 = fbank(wave0)
features1 = fbank(wave1)
assert torch.allclose(features[0], features0)
assert torch.allclose(features[1], features1)
# To compute fbank features for only a specified frame
audio_frames = fbank.convert_samples_to_frames(wave0)
feature_frame_1 = fbank.compute(audio_frames[1])
feature_frame_10 = fbank.compute(audio_frames[10])
assert torch.allclose(features0[1], feature_frame_1)
assert torch.allclose(features0[10], feature_frame_10)
if __name__ == "__main__":
test_fbank()

104
doc/source/conf.py Normal file
View File

@ -0,0 +1,104 @@
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html
# -- Path setup --------------------------------------------------------------
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
# import os
import re
import sphinx_rtd_theme
# import sys
# sys.path.insert(0, os.path.abspath('.'))
# -- Project information -----------------------------------------------------
project = "kaldifeat"
copyright = "2021, Fangjun Kuang"
author = "Fangjun Kuang"
def get_version():
cmake_file = "../../CMakeLists.txt"
with open(cmake_file) as f:
content = f.read()
version = re.search(r"set\(kaldifeat_VERSION (.*)\)", content).group(1)
return version.strip('"')
version = get_version()
release = version
# -- General configuration ---------------------------------------------------
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
"recommonmark",
"sphinx.ext.autodoc",
"sphinx.ext.githubpages",
"sphinx.ext.napoleon",
"sphinx_autodoc_typehints",
"sphinx_rtd_theme",
]
# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = []
source_suffix = {
".rst": "restructuredtext",
".md": "markdown",
}
master_doc = "index"
# -- Options for HTML output -------------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = "sphinx_rtd_theme"
html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
html_show_sourcelink = True
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ["_static"]
pygments_style = "sphinx"
numfig = True
html_context = {
"display_github": True,
"github_user": "csukuangfj",
"github_repo": "kaldifeat",
"github_version": "master",
"conf_py_path": "/kaldifeat/docs/source/",
}
# refer to
# https://sphinx-rtd-theme.readthedocs.io/en/latest/configuring.html
html_theme_options = {
"logo_only": False,
"display_version": True,
"prev_next_buttons_location": "bottom",
"style_external_links": True,
}

24
doc/source/index.rst Normal file
View File

@ -0,0 +1,24 @@
.. kaldifeat documentation master file, created by
sphinx-quickstart on Fri Jul 16 20:15:27 2021.
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
kaldifeat
=========
`kaldifeat <https://github.com/csukuangfj/kaldifeat>`_ implements
feature extraction algorithms **compatible** with kaldi using PyTorch, supporting CUDA
as well as autograd.
Currently, only fbank features are supported.
It can produce the same feature output as ``compute-fbank-feats`` (from kaldi)
when given the same options.
.. toctree::
:maxdepth: 2
:caption: Contents:
installation
usage

View File

@ -0,0 +1,54 @@
Installation
============
.. _from source:
Install kaldifeat from source
-----------------------------
You have to install ``cmake`` and ``PyTorch`` first.
- ``cmake`` 3.11 is known to work. Other CMake versions may also work.
- ``PyTorch`` 1.8.1 is known to work. Other PyTorch versions may also work.
- Python >= 3.6
The commands to install ``kaldifeat`` from source are:
.. code-block:: bash
git clone https://github.com/csukuangfj/kaldifeat
cd kaldifeat
python3 setup.py install
To test that you have installed ``kaldifeat`` successfully, please run:
.. code-block:: bash
python3 -c "import kaldifeat; print(kaldifeat.__version__)"
It should print the version, e.g., ``1.0``.
Install kaldifeat from PyPI
---------------------------
The pre-built ``kaldifeat`` hosted on PyPI uses PyTorch 1.8.1.
If you install ``kaldifeat`` using pip, it will replace your locally
installed PyTorch automatically with PyTorch 1.8.1.
If you don't want this happen, please `Install kaldifeat from source`_.
The command to install ``kaldifeat`` from PyPI is:
.. code-block:: bash
pip install kaldifeat
To test that you have installed ``kaldifeat`` successfully, please run:
.. code-block:: bash
python3 -c "import kaldifeat; print(kaldifeat.__version__)"
It should print the version, e.g., ``1.0``.

212
doc/source/usage.rst Normal file
View File

@ -0,0 +1,212 @@
Usage
=====
Let us first see the help message of kaldi's ``compute-fbank-feats``:
.. code-block:: bash
$ compute-fbank-feats
Create Mel-filter bank (FBANK) feature files.
Usage: compute-fbank-feats [options...] <wav-rspecifier> <feats-wspecifier>
Options:
--allow-downsample : If true, allow the input waveform to have a higher frequency than the specified --sample-frequency (and we'll downsample). (bool, default = false)
--allow-upsample : If true, allow the input waveform to have a lower frequency than the specified --sample-frequency (and we'll upsample). (bool, default = false)
--blackman-coeff : Constant coefficient for generalized Blackman window. (float, default = 0.42)
--channel : Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (int, default = -1)
--debug-mel : Print out debugging information for mel bin computation (bool, default = false)
--dither : Dithering constant (0.0 means no dither). If you turn this off, you should set the --energy-floor option, e.g. to 1.0 or 0.1 (float, default = 1)
--energy-floor : Floor on energy (absolute, not relative) in FBANK computation. Only makes a difference if --use-energy=true; only necessary if --dither=0.0. Suggested values: 0.1 or 1.0 (float, default = 0)
--frame-length : Frame length in milliseconds (float, default = 25)
--frame-shift : Frame shift in milliseconds (float, default = 10)
--high-freq : High cutoff frequency for mel bins (if <= 0, offset from Nyquist) (float, default = 0)
--htk-compat : If true, put energy last. Warning: not sufficient to get HTK compatible features (need to change other parameters). (bool, default = false)
--low-freq : Low cutoff frequency for mel bins (float, default = 20)
--max-feature-vectors : Memory optimization. If larger than 0, periodically remove feature vectors so that only this number of the latest feature vectors is retained. (int, default = -1)
--min-duration : Minimum duration of segments to process (in seconds). (float, default = 0)
--num-mel-bins : Number of triangular mel-frequency bins (int, default = 23)
--output-format : Format of the output files [kaldi, htk] (string, default = "kaldi")
--preemphasis-coefficient : Coefficient for use in signal preemphasis (float, default = 0.97)
--raw-energy : If true, compute energy before preemphasis and windowing (bool, default = true)
--remove-dc-offset : Subtract mean from waveform on each frame (bool, default = true)
--round-to-power-of-two : If true, round window size to power of two by zero-padding input to FFT. (bool, default = true)
--sample-frequency : Waveform data sample frequency (must match the waveform file, if specified there) (float, default = 16000)
--snip-edges : If true, end effects will be handled by outputting only frames that completely fit in the file, and the number of frames depends on the frame-length. If false, the number of frames depends only on the frame-shift, and we reflect the data at the ends. (bool, default = true)
--subtract-mean : Subtract mean of each feature file [CMS]; not recommended to do it this way. (bool, default = false)
--use-energy : Add an extra dimension with energy to the FBANK output. (bool, default = false)
--use-log-fbank : If true, produce log-filterbank, else produce linear. (bool, default = true)
--use-power : If true, use power, else use magnitude. (bool, default = true)
--utt2spk : Utterance to speaker-id map (if doing VTLN and you have warps per speaker) (string, default = "")
--vtln-high : High inflection point in piecewise linear VTLN warping function (if negative, offset from high-mel-freq (float, default = -500)
--vtln-low : Low inflection point in piecewise linear VTLN warping function (float, default = 100)
--vtln-map : Map from utterance or speaker-id to vtln warp factor (rspecifier) (string, default = "")
--vtln-warp : Vtln warp factor (only applicable if vtln-map not specified) (float, default = 1)
--window-type : Type of window ("hamming"|"hanning"|"povey"|"rectangular"|"sine"|"blackmann") (string, default = "povey")
--write-utt2dur : Wspecifier to write duration of each utterance in seconds, e.g. 'ark,t:utt2dur'. (string, default = "")
Standard options:
--config : Configuration file to read (this option may be repeated) (string, default = "")
--help : Print out usage message (bool, default = false)
--print-args : Print the command line arguments (to stderr) (bool, default = true)
--verbose : Verbose level (higher->more logging) (int, default = 0)
FbankOptions
------------
``kaldifeat`` reuses the same options from kaldi's ``compute-fbank-feats``.
The following shows the default values of ``kaldifeat.FbankOptions``:
.. code-block:: python
>>> import kaldifeat
>>> fbank_opts = kaldifeat.FbankOptions()
>>> print(fbank_opts)
frame_opts:
samp_freq: 16000
frame_shift_ms: 10
frame_length_ms: 25
dither: 1
preemph_coeff: 0.97
remove_dc_offset: 1
window_type: povey
round_to_power_of_two: 1
blackman_coeff: 0.42
snip_edges: 1
mel_opts:
num_bins: 23
low_freq: 20
high_freq: 0
vtln_low: 100
vtln_high: -500
debug_mel: 0
htk_mode: 0
use_energy: 0
energy_floor: 0
raw_energy: 1
htk_compat: 0
use_log_fbank: 1
use_power: 1
device: cpu
It consists of three parts:
- ``frame_opts``
Options in this part are accessed by ``frame_opts.xxx``. That is, to access
the sample rate, you use:
.. code-block:: python
>>> fbank_opts = kaldifeat.FbankOptions()
>>> print(fbank_opts.frame_opts.samp_freq)
16000.0
- ``mel_opts``
Options in this part are accessed by ``mel_opts.xxx``. That is, to access
the number of mel bins, you use:
.. code-block:: python
>>> fbank_opts = kaldifeat.FbankOptions()
>>> print(fbank_opts.mel_opts.num_bins)
23
- fbank related
Options in this part are accessed directly. That is, to access the device
field, you use:
.. code-block::
>>> print(fbank_opts.device)
cpu
>>> fbank_opts.device = 'cuda:0'
>>> print(fbank_opts.device)
cuda:0
>>> import torch
>>> fbank_opts.device = torch.device('cuda', 0)
>>> print(fbank_opts.device)
cuda:0
To change the sample rate to 8000, you can use:
.. code-block:: python
>>> fbank_opts = kaldifeat.FbankOptions()
>>> print(fbank_opts.frame_opts.samp_freq)
16000.0
>>> fbank_opts.frame_opts.samp_freq = 8000
>>> print(fbank_opts.frame_opts.samp_freq)
8000.0
To change ``snip_edges`` to ``False``, you can use:
.. code-block:: python
>>> fbank_opts.frame_opts.snip_edges = False
>>> print(fbank_opts.frame_opts.snip_edges)
False
To change number of mel bins to 80, you can use:
.. code-block:: python
>>> print(fbank_opts.mel_opts.num_bins)
23
>>> fbank_opts.mel_opts.num_bins = 80
>>> print(fbank_opts.mel_opts.num_bins)
80
To change the device to ``cuda``, you can use:
Fbank
-----
The following shows how to use ``kaldifeat.Fbank`` to compute
the fbank features of sound files.
First, let us generate two sound files using ``sox``:
.. code-block:: bash
# generate a wav of two seconds, containing a sine-wave
# swept from 300 Hz to 3300 Hz
sox -n -r 16000 -b 16 test.wav synth 1.2 sine 300-3300
# another sound file with 0.5 seconds
sox -n -r 16000 -b 16 test2.wav synth 0.5 sine 300-3300
.. hint::
You can find the above two files by visiting the following two links:
- `test.wav <https://github.com/csukuangfj/kaldifeat/blob/master/kaldifeat/python/tests/test_data/test.wav>`_
- `test2.wav <https://github.com/csukuangfj/kaldifeat/blob/master/kaldifeat/python/tests/test_data/test2.wav>`_
The `following code <https://github.com/csukuangfj/kaldifeat/blob/master/kaldifeat/python/tests/test_fbank.py>`_
shows the usage of ``kaldifeat.Fbank``.
It shows:
- How to read a sound file. Note that audio samples are scaled to the range [-32768, 32768].
The intention is to produce the same output as kaldi. You don't need to scale it if
you don't care about the compatibility with kaldi
- ``kaldifeat.Fbank`` supports CUDA as well as CPU
- ``kaldifeat.Fbank`` supports processing sound file in a batch as well as accepting
a single sound file
.. literalinclude:: ./code/test_fbank.py
:caption: Demo of ``kaldifeat.Fbank``
:language: python

View File

@ -1,3 +1,4 @@
import torch
from _kaldifeat import FbankOptions, FrameExtractionOptions, MelBanksOptions
from .fbank import Fbank

View File

@ -0,0 +1,59 @@
#!/bin/bash
#
# Copyright 2020 Mobvoi Inc. (authors: Fangjun Kuang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
echo "cuda version: $cuda"
case "$cuda" in
10.0)
url=https://developer.nvidia.com/compute/cuda/10.0/Prod/local_installers/cuda_10.0.130_410.48_linux
;;
10.1)
# WARNING: there are bugs in
# https://developer.nvidia.com/compute/cuda/10.1/Prod/local_installers/cuda_10.1.105_418.39_linux.run
# with GCC 7. Please use the following version
url=http://developer.download.nvidia.com/compute/cuda/10.1/Prod/local_installers/cuda_10.1.243_418.87.00_linux.run
;;
10.2)
url=http://developer.download.nvidia.com/compute/cuda/10.2/Prod/local_installers/cuda_10.2.89_440.33.01_linux.run
;;
11.0)
url=http://developer.download.nvidia.com/compute/cuda/11.0.2/local_installers/cuda_11.0.2_450.51.05_linux.run
;;
11.1)
# url=https://developer.download.nvidia.com/compute/cuda/11.1.0/local_installers/cuda_11.1.0_455.23.05_linux.run
url=https://developer.download.nvidia.com/compute/cuda/11.1.1/local_installers/cuda_11.1.1_455.32.00_linux.run
;;
*)
echo "Unknown cuda version: $cuda"
exit 1
;;
esac
function retry() {
$* || (sleep 1 && $*) || (sleep 2 && $*) || (sleep 4 && $*) || (sleep 8 && $*)
}
retry curl -LSs -O $url
filename=$(basename $url)
echo "filename: $filename"
chmod +x ./$filename
sudo ./$filename --toolkit --silent
rm -fv ./$filename
export CUDA_HOME=/usr/local/cuda
export PATH=$CUDA_HOME/bin:$PATH
export LD_LIBRARY_PATH=$CUDA_HOME/lib:$LD_LIBRARY_PATH
export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH

View File

@ -0,0 +1,58 @@
#!/bin/bash
#
# Copyright 2020 Mobvoi Inc. (authors: Fangjun Kuang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
case $cuda in
10.0)
filename=cudnn-10.0-linux-x64-v7.6.5.32.tgz
url=http://www.mediafire.com/file/1037lb1vmj9qdtq/cudnn-10.0-linux-x64-v7.6.5.32.tgz/file
;;
10.1)
filename=cudnn-10.1-linux-x64-v8.0.2.39.tgz
url=http://www.mediafire.com/file/fnl2wg0h757qhd7/cudnn-10.1-linux-x64-v8.0.2.39.tgz/file
;;
10.2)
filename=cudnn-10.2-linux-x64-v8.0.2.39.tgz
url=http://www.mediafire.com/file/sc2nvbtyg0f7ien/cudnn-10.2-linux-x64-v8.0.2.39.tgz/file
;;
11.0)
filename=cudnn-11.0-linux-x64-v8.0.5.39.tgz
url=https://www.mediafire.com/file/abyhnls106ko9kp/cudnn-11.0-linux-x64-v8.0.5.39.tgz/file
;;
11.1)
filename=cudnn-11.1-linux-x64-v8.0.5.39.tgz
url=https://www.mediafire.com/file/qx55zd65773xonv/cudnn-11.1-linux-x64-v8.0.5.39.tgz/file
;;
*)
echo "Unsupported cuda version: $cuda"
exit 1
;;
esac
function retry() {
$* || (sleep 1 && $*) || (sleep 2 && $*) || (sleep 4 && $*) || (sleep 8 && $*)
}
# It is forked from https://github.com/Juvenal-Yescas/mediafire-dl
# https://github.com/Juvenal-Yescas/mediafire-dl/pull/2 changes the filename and breaks the CI.
# We use a separate fork to keep the link fixed.
retry wget https://raw.githubusercontent.com/csukuangfj/mediafire-dl/master/mediafire_dl.py
sed -i 's/quiet=False/quiet=True/' mediafire_dl.py
retry python3 mediafire_dl.py "$url"
sudo tar xf ./$filename -C /usr/local
rm -v ./$filename
sudo sed -i '59i#define CUDNN_MAJOR 8' /usr/local/cuda/include/cudnn.h

View File

@ -0,0 +1,108 @@
#!/bin/bash
#
# Copyright 2020 Mobvoi Inc. (authors: Fangjun Kuang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
case ${torch} in
1.5.*)
case ${cuda} in
10.1)
package="torch==${torch}+cu101"
url=https://download.pytorch.org/whl/torch_stable.html
;;
10.2)
package="torch==${torch}"
# Leave url empty to use PyPI.
# torch_stable provides cu92 but we want cu102
url=
;;
esac
;;
1.6.0)
case ${cuda} in
10.1)
package="torch==1.6.0+cu101"
url=https://download.pytorch.org/whl/torch_stable.html
;;
10.2)
package="torch==1.6.0"
# Leave it empty to use PyPI.
# torch_stable provides cu92 but we want cu102
url=
;;
esac
;;
1.7.*)
case ${cuda} in
10.1)
package="torch==${torch}+cu101"
url=https://download.pytorch.org/whl/torch_stable.html
;;
10.2)
package="torch==${torch}"
# Leave it empty to use PyPI.
# torch_stable provides cu92 but we want cu102
url=
;;
11.0)
package="torch==${torch}+cu110"
url=https://download.pytorch.org/whl/torch_stable.html
;;
esac
;;
1.8.*)
case ${cuda} in
10.1)
package="torch==${torch}+cu101"
url=https://download.pytorch.org/whl/torch_stable.html
;;
10.2)
package="torch==${torch}"
# Leave it empty to use PyPI.
url=
;;
11.1)
package="torch==${torch}+cu111"
url=https://download.pytorch.org/whl/torch_stable.html
;;
esac
;;
1.9.0)
case ${cuda} in
10.2)
package="torch==${torch}"
# Leave it empty to use PyPI.
url=
;;
11.1)
package="torch==${torch}+cu111"
url=https://download.pytorch.org/whl/torch_stable.html
;;
esac
;;
*)
echo "Unsupported PyTorch version: ${torch}"
exit 1
;;
esac
function retry() {
$* || (sleep 1 && $*) || (sleep 2 && $*) || (sleep 4 && $*) || (sleep 8 && $*)
}
if [ x"${url}" == "x" ]; then
retry python3 -m pip install -q $package
else
retry python3 -m pip install -q $package -f $url
fi