Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-10 10:32:17 +00:00)

support exporting to ncnn format via PNNX (#571)

parent 436942211c
commit 099cd3a215
.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml (vendored, executable file, 160 lines)
@@ -0,0 +1,160 @@
#!/usr/bin/env bash

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

cd egs/librispeech/ASR

repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03

log "Downloading pre-trained model from $repo_url"
git lfs install
git clone $repo_url
repo=$(basename $repo_url)

log "Display test files"
tree $repo/
soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav

pushd $repo/exp
ln -s pretrained-iter-468000-avg-16.pt pretrained.pt
ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt
popd

log "Install ncnn and pnnx"

# We are using a modified ncnn here. Will try to merge it to the official repo
# of ncnn
git clone https://github.com/csukuangfj/ncnn
pushd ncnn
git submodule init
git submodule update python/pybind11
python3 setup.py bdist_wheel
ls -lh dist/
pip install dist/*.whl
cd tools/pnnx
mkdir build
cd build
cmake ..
make -j4 pnnx

./src/pnnx || echo "pass"

popd

log "Test exporting to pnnx format"

./lstm_transducer_stateless2/export.py \
  --exp-dir $repo/exp \
  --bpe-model $repo/data/lang_bpe_500/bpe.model \
  --epoch 99 \
  --avg 1 \
  --use-averaged-model 0 \
  --pnnx 1

./ncnn/tools/pnnx/build/src/pnnx $repo/exp/encoder_jit_trace-pnnx.pt
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/decoder_jit_trace-pnnx.pt
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/joiner_jit_trace-pnnx.pt

./lstm_transducer_stateless2/ncnn-decode.py \
  --bpe-model-filename $repo/data/lang_bpe_500/bpe.model \
  --encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \
  --encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \
  --decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \
  --decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \
  --joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \
  --joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \
  $repo/test_wavs/1089-134686-0001.wav

./lstm_transducer_stateless2/streaming-ncnn-decode.py \
  --bpe-model-filename $repo/data/lang_bpe_500/bpe.model \
  --encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \
  --encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \
  --decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \
  --decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \
  --joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \
  --joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \
  $repo/test_wavs/1089-134686-0001.wav

log "Test exporting with torch.jit.trace()"

./lstm_transducer_stateless2/export.py \
  --exp-dir $repo/exp \
  --bpe-model $repo/data/lang_bpe_500/bpe.model \
  --epoch 99 \
  --avg 1 \
  --use-averaged-model 0 \
  --jit-trace 1

log "Decode with models exported by torch.jit.trace()"

./lstm_transducer_stateless2/jit_pretrained.py \
  --bpe-model $repo/data/lang_bpe_500/bpe.model \
  --encoder-model-filename $repo/exp/encoder_jit_trace.pt \
  --decoder-model-filename $repo/exp/decoder_jit_trace.pt \
  --joiner-model-filename $repo/exp/joiner_jit_trace.pt \
  $repo/test_wavs/1089-134686-0001.wav \
  $repo/test_wavs/1221-135766-0001.wav \
  $repo/test_wavs/1221-135766-0002.wav

for sym in 1 2 3; do
  log "Greedy search with --max-sym-per-frame $sym"

  ./lstm_transducer_stateless2/pretrained.py \
    --method greedy_search \
    --max-sym-per-frame $sym \
    --checkpoint $repo/exp/pretrained.pt \
    --bpe-model $repo/data/lang_bpe_500/bpe.model \
    $repo/test_wavs/1089-134686-0001.wav \
    $repo/test_wavs/1221-135766-0001.wav \
    $repo/test_wavs/1221-135766-0002.wav
done

for method in modified_beam_search beam_search fast_beam_search; do
  log "$method"

  ./lstm_transducer_stateless2/pretrained.py \
    --method $method \
    --beam-size 4 \
    --checkpoint $repo/exp/pretrained.pt \
    --bpe-model $repo/data/lang_bpe_500/bpe.model \
    $repo/test_wavs/1089-134686-0001.wav \
    $repo/test_wavs/1221-135766-0001.wav \
    $repo/test_wavs/1221-135766-0002.wav
done

echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"ncnn" ]]; then
  mkdir -p lstm_transducer_stateless2/exp
  ln -s $PWD/$repo/exp/pretrained.pt lstm_transducer_stateless2/exp/epoch-999.pt
  ln -s $PWD/$repo/data/lang_bpe_500 data/

  ls -lh data
  ls -lh lstm_transducer_stateless2/exp

  log "Decoding test-clean and test-other"

  # use a small value for decoding with CPU
  max_duration=100

  for method in greedy_search fast_beam_search modified_beam_search; do
    log "Decoding with $method"

    ./lstm_transducer_stateless2/decode.py \
      --decoding-method $method \
      --epoch 999 \
      --avg 1 \
      --use-averaged-model 0 \
      --max-duration $max_duration \
      --exp-dir lstm_transducer_stateless2/exp
  done

  rm lstm_transducer_stateless2/exp/*.pt
fi
.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml (vendored, normal file, 136 lines)
@@ -0,0 +1,136 @@
name: run-librispeech-lstm-transducer-2022-09-03

on:
  push:
    branches:
      - master
  pull_request:
    types: [labeled]

  schedule:
    # minute (0-59)
    # hour (0-23)
    # day of the month (1-31)
    # month (1-12)
    # day of the week (0-6)
    # nightly build at 15:50 UTC time every day
    - cron: "50 15 * * *"

jobs:
  run_librispeech_pruned_transducer_stateless3_2022_05_13:
    if: github.event.label.name == 'ncnn' || github.event_name == 'push' || github.event_name == 'schedule'
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [ubuntu-18.04]
        python-version: [3.8]

      fail-fast: false

    steps:
      - uses: actions/checkout@v2
        with:
          fetch-depth: 0

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
          cache: 'pip'
          cache-dependency-path: '**/requirements-ci.txt'

      - name: Install Python dependencies
        run: |
          grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
          pip uninstall -y protobuf
          pip install --no-binary protobuf protobuf

      - name: Cache kaldifeat
        id: my-cache
        uses: actions/cache@v2
        with:
          path: |
            ~/tmp/kaldifeat
          key: cache-tmp-${{ matrix.python-version }}

      - name: Install kaldifeat
        if: steps.my-cache.outputs.cache-hit != 'true'
        shell: bash
        run: |
          .github/scripts/install-kaldifeat.sh

      - name: Cache LibriSpeech test-clean and test-other datasets
        id: libri-test-clean-and-test-other-data
        uses: actions/cache@v2
        with:
          path: |
            ~/tmp/download
          key: cache-libri-test-clean-and-test-other

      - name: Download LibriSpeech test-clean and test-other
        if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
        shell: bash
        run: |
          .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh

      - name: Prepare manifests for LibriSpeech test-clean and test-other
        shell: bash
        run: |
          .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh

      - name: Cache LibriSpeech test-clean and test-other fbank features
        id: libri-test-clean-and-test-other-fbank
        uses: actions/cache@v2
        with:
          path: |
            ~/tmp/fbank-libri
          key: cache-libri-fbank-test-clean-and-test-other-v2

      - name: Compute fbank for LibriSpeech test-clean and test-other
        if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
        shell: bash
        run: |
          .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh

      - name: Inference with pre-trained model
        shell: bash
        env:
          GITHUB_EVENT_NAME: ${{ github.event_name }}
          GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
        run: |
          mkdir -p egs/librispeech/ASR/data
          ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
          ls -lh egs/librispeech/ASR/data/*

          sudo apt-get -qq install git-lfs tree sox
          export PYTHONPATH=$PWD:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH

          .github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml

      - name: Display decoding results for lstm_transducer_stateless2
        if: github.event_name == 'schedule' || github.event.label.name == 'ncnn'
        shell: bash
        run: |
          cd egs/librispeech/ASR
          tree lstm_transducer_stateless2/exp
          cd lstm_transducer_stateless2/exp
          echo "===greedy search==="
          find greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
          find greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2

          echo "===fast_beam_search==="
          find fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
          find fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2

          echo "===modified beam search==="
          find modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
          find modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2

      - name: Upload decoding results for lstm_transducer_stateless2
        uses: actions/upload-artifact@v2
        if: github.event_name == 'schedule' || github.event.label.name == 'ncnn'
        with:
          name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-lstm_transducer_stateless2-2022-09-03
          path: egs/librispeech/ASR/lstm_transducer_stateless2/exp/
.gitignore (vendored, 2 lines changed)
@@ -11,3 +11,5 @@ log
*.bak
*-bak
*bak.py
+*.param
+*.bin
(New binary file: an image, 413 KiB; not shown.)
@@ -6,3 +6,4 @@ LibriSpeech

tdnn_lstm_ctc
conformer_ctc
+lstm_pruned_stateless_transducer
@@ -0,0 +1,625 @@
Transducer
==========

.. hint::

   Please scroll down to the bottom of this page to find download links
   for pretrained models if you don't want to train a model from scratch.


This tutorial shows you how to train a transducer model
with the `LibriSpeech <https://www.openslr.org/12>`_ dataset.

We use pruned RNN-T to compute the loss.

.. note::

   You can find the paper about pruned RNN-T at the following address:

   `<https://arxiv.org/abs/2206.13236>`_

The transducer model consists of 3 parts:

- Encoder, a.k.a. the transcriber. We use an LSTM model.
- Decoder, a.k.a. the predictor. We use a model consisting of ``nn.Embedding``
  and ``nn.Conv1d``.
- Joiner, a.k.a. the joint network.

.. caution::

   Contrary to conventional RNN-T models, we use a stateless decoder.
   That is, it has no recurrent connections.
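
To make this concrete, below is a minimal sketch of such a stateless decoder
(an illustration only: the class name and sizes are made up, and the actual
decoder used by this recipe differs in details such as scaled layers and
blank handling):

.. code-block:: python

   import torch
   import torch.nn as nn


   class StatelessDecoder(nn.Module):
       """Predict from only the last ``context_size`` tokens.

       There is no recurrent hidden state to carry around.
       """

       def __init__(self, vocab_size: int, embed_dim: int, context_size: int):
           super().__init__()
           self.embedding = nn.Embedding(vocab_size, embed_dim)
           # A 1-D convolution over the limited left context plays the
           # role of the recurrent state in a conventional prediction network.
           self.conv = nn.Conv1d(embed_dim, embed_dim, kernel_size=context_size)

       def forward(self, y: torch.Tensor) -> torch.Tensor:
           # y: (N, context_size), ids of the most recent tokens
           embed = self.embedding(y)       # (N, context_size, embed_dim)
           embed = embed.permute(0, 2, 1)  # (N, embed_dim, context_size)
           return self.conv(embed).squeeze(-1)  # (N, embed_dim)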

.. hint::

   Since the encoder model is an LSTM, not a Transformer/Conformer, the
   resulting model is suitable for streaming/online ASR.
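
In streaming recognition, the LSTM's hidden and cell states are simply
carried across feature chunks. The following sketch shows the idea; the
shapes follow ``Model.run_encoder`` from ``ncnn-decode.py`` added in this
PR, and ``run_encoder`` itself is assumed, so treat it as pseudocode:

.. code-block:: python

   import torch

   def streaming_encode(run_encoder, feature_chunks,
                        num_layers=12, d_model=512, rnn_hidden_size=1024):
       """Run the encoder chunk by chunk, carrying (hx, cx) across calls.

       run_encoder(x, states) is assumed to return
       (encoder_out, encoder_out_lens, hx, cx).
       """
       hx = torch.zeros(num_layers, d_model)
       cx = torch.zeros(num_layers, rnn_hidden_size)
       outs = []
       for chunk in feature_chunks:
           encoder_out, _, hx, cx = run_encoder(chunk, (hx, cx))
           outs.append(encoder_out)
       return torch.cat(outs, dim=0)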

Which model to use
------------------

Currently, there are two folders about LSTM stateless transducer training:

- ``(1)`` `<https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/lstm_transducer_stateless>`_

  This recipe uses only LibriSpeech during training.

- ``(2)`` `<https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/lstm_transducer_stateless2>`_

  This recipe uses GigaSpeech + LibriSpeech during training.

``(1)`` and ``(2)`` use the same model architecture. The only difference is that ``(2)`` supports
multi-dataset training. Since ``(2)`` uses more data, it has a lower WER than ``(1)``, but it needs
more training time.

We use ``lstm_transducer_stateless2`` as an example below.

.. note::

   You need to download the `GigaSpeech <https://github.com/SpeechColab/GigaSpeech>`_ dataset
   to run ``(2)``. If you have only the ``LibriSpeech`` dataset available, feel free to use ``(1)``.

Data preparation
----------------

.. code-block:: bash

   $ cd egs/librispeech/ASR
   $ ./prepare.sh

   # If you use (1), you can **skip** the following command
   $ ./prepare_giga_speech.sh

The script ``./prepare.sh`` handles the data preparation for you, **automagically**.
All you need to do is to run it.

The data preparation contains several stages. You can use the following two
options:

- ``--stage``
- ``--stop-stage``

to control which stage(s) should be run. By default, all stages are executed.

For example,

.. code-block:: bash

   $ cd egs/librispeech/ASR
   $ ./prepare.sh --stage 0 --stop-stage 0

means to run only stage 0.

To run stage 2 to stage 5, use:

.. code-block:: bash

   $ ./prepare.sh --stage 2 --stop-stage 5

.. hint::

   If you have pre-downloaded the `LibriSpeech <https://www.openslr.org/12>`_
   dataset and the `musan <http://www.openslr.org/17/>`_ dataset, say,
   they are saved in ``/tmp/LibriSpeech`` and ``/tmp/musan``, you can modify
   the ``dl_dir`` variable in ``./prepare.sh`` to point to ``/tmp`` so that
   ``./prepare.sh`` won't re-download them.

.. note::

   All files generated by ``./prepare.sh``, e.g., features, lexicon, etc.,
   are saved in the ``./data`` directory.

We provide the following YouTube video showing how to run ``./prepare.sh``.

.. note::

   To get the latest news about `next-gen Kaldi <https://github.com/k2-fsa>`_, please subscribe to
   the following YouTube channel by `Nadira Povey <https://www.youtube.com/channel/UC_VaumpkmINz1pNkFXAN9mw>`_:

   `<https://www.youtube.com/channel/UC_VaumpkmINz1pNkFXAN9mw>`_

.. youtube:: ofEIoJL-mGM

Training
--------

Configurable options
~~~~~~~~~~~~~~~~~~~~

.. code-block:: bash

   $ cd egs/librispeech/ASR
   $ ./lstm_transducer_stateless2/train.py --help

shows you the training options that can be passed from the command line.
The following options are used quite often:

- ``--full-libri``

  If it's True, the training part uses all the training data, i.e.,
  960 hours. Otherwise, the training part uses only the subset
  ``train-clean-100``, which has 100 hours of training data.

  .. CAUTION::

     The training set is perturbed by speed with two factors: 0.9 and 1.1.
     If ``--full-libri`` is True, each epoch actually processes
     ``3x960 == 2880`` hours of data.

- ``--num-epochs``

  It is the number of epochs to train. For instance,
  ``./lstm_transducer_stateless2/train.py --num-epochs 30`` trains for 30 epochs
  and generates ``epoch-1.pt``, ``epoch-2.pt``, ..., ``epoch-30.pt``
  in the folder ``./lstm_transducer_stateless2/exp``.

- ``--start-epoch``

  It's used to resume training.
  ``./lstm_transducer_stateless2/train.py --start-epoch 10`` loads the
  checkpoint ``./lstm_transducer_stateless2/exp/epoch-9.pt`` and starts
  training from epoch 10, based on the state from epoch 9.

- ``--world-size``

  It is used for multi-GPU single-machine DDP training.

  - (a) If it is 1, then no DDP training is used.

  - (b) If it is 2, then GPU 0 and GPU 1 are used for DDP training.

  The following shows some use cases with it.

  **Use case 1**: You have 4 GPUs, but you only want to use GPU 0 and
  GPU 2 for training. You can do the following:

  .. code-block:: bash

     $ cd egs/librispeech/ASR
     $ export CUDA_VISIBLE_DEVICES="0,2"
     $ ./lstm_transducer_stateless2/train.py --world-size 2

  **Use case 2**: You have 4 GPUs and you want to use all of them
  for training. You can do the following:

  .. code-block:: bash

     $ cd egs/librispeech/ASR
     $ ./lstm_transducer_stateless2/train.py --world-size 4

  **Use case 3**: You have 4 GPUs but you only want to use GPU 3
  for training. You can do the following:

  .. code-block:: bash

     $ cd egs/librispeech/ASR
     $ export CUDA_VISIBLE_DEVICES="3"
     $ ./lstm_transducer_stateless2/train.py --world-size 1

  .. caution::

     Only multi-GPU single-machine DDP training is implemented at present.
     Multi-GPU multi-machine DDP training will be added later.

- ``--max-duration``

  It specifies the number of seconds over all utterances in a
  batch, before **padding**.
  If you encounter CUDA OOM, please reduce it. A sketch of how
  duration-based batching works is given after this list.

  .. HINT::

     Due to padding, the number of seconds of all utterances in a
     batch will usually be larger than ``--max-duration``.

     A larger value for ``--max-duration`` may cause OOM during training,
     while a smaller value may increase the training time. You have to
     tune it.

- ``--giga-prob``

  The probability of selecting a batch from the ``GigaSpeech`` dataset.
  Note: It is available only for ``(2)``.
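
As promised in the ``--max-duration`` item above, here is a simplified
sketch of duration-based batching. The real batching is done by Lhotse's
samplers and is more sophisticated (e.g., bucketing utterances of similar
length), so this only illustrates the semantics of the option:

.. code-block:: python

   from typing import Iterable, List, Tuple

   def batch_by_duration(
       utts: Iterable[Tuple[str, float]], max_duration: float
   ) -> List[List[str]]:
       """Group (utterance_id, seconds) pairs into batches whose total
       duration, before padding, does not exceed ``max_duration``."""
       batches: List[List[str]] = []
       cur: List[str] = []
       cur_dur = 0.0
       for utt_id, dur in utts:
           if cur and cur_dur + dur > max_duration:
               batches.append(cur)
               cur, cur_dur = [], 0.0
           cur.append(utt_id)
           cur_dur += dur
       if cur:
           batches.append(cur)
       return batches

   # E.g., with max_duration=100, utterances of about 10 s each are
   # grouped roughly 10 per batch.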

Pre-configured options
~~~~~~~~~~~~~~~~~~~~~~

There are some training options, e.g., weight decay,
number of warmup steps, results dir, etc.,
that are not passed from the command line.
They are pre-configured by the function ``get_params()`` in
`lstm_transducer_stateless2/train.py <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/lstm_transducer_stateless2/train.py>`_.

You don't need to change these pre-configured parameters. If you really need to change
them, please modify ``./lstm_transducer_stateless2/train.py`` directly.

Training logs
~~~~~~~~~~~~~

Training logs and checkpoints are saved in ``lstm_transducer_stateless2/exp``.
You will find the following files in that directory:

- ``epoch-1.pt``, ``epoch-2.pt``, ...

  These are checkpoint files saved at the end of each epoch, containing model
  ``state_dict`` and optimizer ``state_dict``.
  To resume training from some checkpoint, say ``epoch-10.pt``, you can use:

  .. code-block:: bash

     $ ./lstm_transducer_stateless2/train.py --start-epoch 11

- ``checkpoint-436000.pt``, ``checkpoint-438000.pt``, ...

  These are checkpoint files saved every ``--save-every-n`` batches,
  containing model ``state_dict`` and optimizer ``state_dict``.
  To resume training from some checkpoint, say ``checkpoint-436000.pt``, you can use:

  .. code-block:: bash

     $ ./lstm_transducer_stateless2/train.py --start-batch 436000

- ``tensorboard/``

  This folder contains TensorBoard logs. Training loss, validation loss, learning
  rate, etc., are recorded in these logs. You can visualize them by:

  .. code-block:: bash

     $ cd lstm_transducer_stateless2/exp/tensorboard
     $ tensorboard dev upload --logdir . --description "LSTM transducer training for LibriSpeech with icefall"

  It will print something like the following:

  .. code-block::

     TensorFlow installation not found - running with reduced feature set.
     Upload started and will continue reading any new data as it's added to the logdir.

     To stop uploading, press Ctrl-C.

     New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/cj2vtPiwQHKN9Q1tx6PTpg/

     [2022-09-20T15:50:50] Started scanning logdir.
     Uploading 4468 scalars...
     [2022-09-20T15:53:02] Total uploaded: 210171 scalars, 0 tensors, 0 binary objects
     Listening for new data in logdir...

  Note that there is a URL in the above output. Click it and you will see
  the following screenshot:

  .. figure:: images/librispeech-lstm-transducer-tensorboard-log.png
     :width: 600
     :alt: TensorBoard screenshot
     :align: center
     :target: https://tensorboard.dev/experiment/lzGnETjwRxC3yghNMd4kPw/

     TensorBoard screenshot.

  .. hint::

     If you don't have access to Google, you can use the following command
     to view the tensorboard log locally:

     .. code-block:: bash

        cd lstm_transducer_stateless2/exp/tensorboard
        tensorboard --logdir . --port 6008

     It will print the following message:

     .. code-block::

        Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
        TensorBoard 2.8.0 at http://localhost:6008/ (Press CTRL+C to quit)

     Now start your browser and go to `<http://localhost:6008>`_ to view the tensorboard
     logs.

- ``log/log-train-xxxx``

  It is the detailed training log in text format, same as the one
  you saw printed to the console during training.

Usage example
~~~~~~~~~~~~~

You can use the following command to start the training using 8 GPUs:

.. code-block:: bash

   export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"

   ./lstm_transducer_stateless2/train.py \
     --world-size 8 \
     --num-epochs 35 \
     --start-epoch 1 \
     --full-libri 1 \
     --exp-dir lstm_transducer_stateless2/exp \
     --max-duration 500 \
     --use-fp16 0 \
     --lr-epochs 10 \
     --num-workers 2 \
     --giga-prob 0.9

Decoding
--------

The decoding part uses checkpoints saved by the training part, so you have
to run the training part first.

.. hint::

   There are two kinds of checkpoints:

   - (1) ``epoch-1.pt``, ``epoch-2.pt``, ..., which are saved at the end
     of each epoch. You can pass ``--epoch`` to
     ``lstm_transducer_stateless2/decode.py`` to use them.

   - (2) ``checkpoint-436000.pt``, ``checkpoint-438000.pt``, ..., which are saved
     every ``--save-every-n`` batches. You can pass ``--iter`` to
     ``lstm_transducer_stateless2/decode.py`` to use them.

   We suggest that you try both types of checkpoints and choose the one
   that produces the lowest WERs.
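
For background, ``--avg n`` averages the parameters of the last ``n``
checkpoints before decoding. A minimal sketch of the idea is below; the
real logic lives in icefall's checkpoint utilities and also powers
``--use-averaged-model``, and the ``"model"`` key is an assumption about
the checkpoint layout:

.. code-block:: python

   import torch

   def average_state_dicts(filenames):
       """Element-wise average of the model weights in ``filenames``.

       Assumes each checkpoint stores its weights under the "model" key
       and that all entries are floating-point tensors.
       """
       avg = torch.load(filenames[0], map_location="cpu")["model"]
       avg = {k: v.to(torch.float64) for k, v in avg.items()}
       for f in filenames[1:]:
           sd = torch.load(f, map_location="cpu")["model"]
           for k in avg:
               avg[k] += sd[k].to(torch.float64)
       return {
           k: (v / len(filenames)).to(torch.float32) for k, v in avg.items()
       }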

.. code-block:: bash

   $ cd egs/librispeech/ASR
   $ ./lstm_transducer_stateless2/decode.py --help

shows the options for decoding.

The following shows two examples:

.. code-block:: bash

   for m in greedy_search fast_beam_search modified_beam_search; do
     for epoch in 17; do
       for avg in 1 2; do
         ./lstm_transducer_stateless2/decode.py \
           --epoch $epoch \
           --avg $avg \
           --exp-dir lstm_transducer_stateless2/exp \
           --max-duration 600 \
           --num-encoder-layers 12 \
           --rnn-hidden-size 1024 \
           --decoding-method $m \
           --use-averaged-model True \
           --beam 4 \
           --max-contexts 4 \
           --max-states 8 \
           --beam-size 4
       done
     done
   done

.. code-block:: bash

   for m in greedy_search fast_beam_search modified_beam_search; do
     for iter in 474000; do
       for avg in 8 10 12 14 16 18; do
         ./lstm_transducer_stateless2/decode.py \
           --iter $iter \
           --avg $avg \
           --exp-dir lstm_transducer_stateless2/exp \
           --max-duration 600 \
           --num-encoder-layers 12 \
           --rnn-hidden-size 1024 \
           --decoding-method $m \
           --use-averaged-model True \
           --beam 4 \
           --max-contexts 4 \
           --max-states 8 \
           --beam-size 4
       done
     done
   done

Export models
-------------

`lstm_transducer_stateless2/export.py <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/lstm_transducer_stateless2/export.py>`_ supports exporting checkpoints from ``lstm_transducer_stateless2/exp`` in the following ways.

Export ``model.state_dict()``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Checkpoints saved by ``lstm_transducer_stateless2/train.py`` also include
``optimizer.state_dict()``. It is useful for resuming training. But after training,
we are interested only in ``model.state_dict()``. You can use the following
command to extract ``model.state_dict()``.

.. code-block:: bash

   # Assume that --iter 468000 --avg 16 produces the smallest WER
   # (You can get such information after running ./lstm_transducer_stateless2/decode.py)

   iter=468000
   avg=16

   ./lstm_transducer_stateless2/export.py \
     --exp-dir ./lstm_transducer_stateless2/exp \
     --bpe-model data/lang_bpe_500/bpe.model \
     --iter $iter \
     --avg $avg

It will generate a file ``./lstm_transducer_stateless2/exp/pretrained.pt``.

.. hint::

   To use the generated ``pretrained.pt`` for ``lstm_transducer_stateless2/decode.py``,
   you can run:

   .. code-block:: bash

      cd lstm_transducer_stateless2/exp
      ln -s pretrained.pt epoch-9999.pt

   And then pass ``--epoch 9999 --avg 1 --use-averaged-model 0`` to
   ``./lstm_transducer_stateless2/decode.py``.

To use the exported model with ``./lstm_transducer_stateless2/pretrained.py``, you
can run:

.. code-block:: bash

   ./lstm_transducer_stateless2/pretrained.py \
     --checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \
     --bpe-model ./data/lang_bpe_500/bpe.model \
     --method greedy_search \
     /path/to/foo.wav \
     /path/to/bar.wav

Export model using ``torch.jit.trace()``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. code-block:: bash

   iter=468000
   avg=16

   ./lstm_transducer_stateless2/export.py \
     --exp-dir ./lstm_transducer_stateless2/exp \
     --bpe-model data/lang_bpe_500/bpe.model \
     --iter $iter \
     --avg $avg \
     --jit-trace 1

It will generate 3 files:

- ``./lstm_transducer_stateless2/exp/encoder_jit_trace.pt``
- ``./lstm_transducer_stateless2/exp/decoder_jit_trace.pt``
- ``./lstm_transducer_stateless2/exp/joiner_jit_trace.pt``

To use the generated files with ``./lstm_transducer_stateless2/jit_pretrained.py``:

.. code-block:: bash

   ./lstm_transducer_stateless2/jit_pretrained.py \
     --bpe-model ./data/lang_bpe_500/bpe.model \
     --encoder-model-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace.pt \
     --decoder-model-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace.pt \
     --joiner-model-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace.pt \
     /path/to/foo.wav \
     /path/to/bar.wav

Export model for ncnn
~~~~~~~~~~~~~~~~~~~~~

We support exporting pretrained LSTM transducer models to
`ncnn <https://github.com/tencent/ncnn>`_ using
`pnnx <https://github.com/Tencent/ncnn/tree/master/tools/pnnx>`_.

First, let us install a modified version of ``ncnn``:

.. code-block:: bash

   git clone https://github.com/csukuangfj/ncnn
   cd ncnn
   git submodule update --recursive --init
   python3 setup.py bdist_wheel
   ls -lh dist/
   pip install ./dist/*.whl

   # now build pnnx
   cd tools/pnnx
   mkdir build
   cd build
   cmake ..
   make -j4
   export PATH=$PWD/src:$PATH

   ./src/pnnx

.. note::

   We assume that you have added the path to the binary ``pnnx`` to the
   environment variable ``PATH``.

Second, let us export the model using ``torch.jit.trace()`` in a way that is suitable
for ``pnnx``:

.. code-block:: bash

   iter=468000
   avg=16

   ./lstm_transducer_stateless2/export.py \
     --exp-dir ./lstm_transducer_stateless2/exp \
     --bpe-model data/lang_bpe_500/bpe.model \
     --iter $iter \
     --avg $avg \
     --pnnx 1

It will generate 3 files:

- ``./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.pt``
- ``./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.pt``
- ``./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.pt``

Third, convert the TorchScript models to ``ncnn`` format:

.. code-block:: bash

   pnnx ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.pt
   pnnx ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.pt
   pnnx ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.pt

It will generate the following files:

- ``./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.param``
- ``./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.bin``
- ``./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.param``
- ``./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.bin``
- ``./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.param``
- ``./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.bin``

To use the above generated files, run:

.. code-block:: bash

   ./lstm_transducer_stateless2/ncnn-decode.py \
     --bpe-model-filename ./data/lang_bpe_500/bpe.model \
     --encoder-param-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.param \
     --encoder-bin-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.bin \
     --decoder-param-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.param \
     --decoder-bin-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.bin \
     --joiner-param-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.param \
     --joiner-bin-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.bin \
     /path/to/foo.wav

.. code-block:: bash

   ./lstm_transducer_stateless2/streaming-ncnn-decode.py \
     --bpe-model-filename ./data/lang_bpe_500/bpe.model \
     --encoder-param-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.param \
     --encoder-bin-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.bin \
     --decoder-param-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.param \
     --decoder-bin-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.bin \
     --joiner-param-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.param \
     --joiner-bin-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.bin \
     /path/to/foo.wav

To use the above generated files in C++, please see
`<https://github.com/k2-fsa/sherpa-ncnn>`_.

It is able to generate a statically linked library that can be run on Linux, Windows,
macOS, Raspberry Pi, etc.

Download pretrained models
--------------------------

If you don't want to train from scratch, you can download the pretrained models
by visiting the following links:

- `<https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03>`_

- `<https://huggingface.co/Zengwei/icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18>`_

See `<https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS.md>`_
for the details of the above pretrained models.

You can find more usages of the pretrained models in
`<https://k2-fsa.github.io/sherpa/python/streaming_asr/lstm/index.html>`_
@@ -116,6 +116,8 @@ class RNN(EncoderInterface):
      Period of auxiliary layers used for random combiner during training.
      If set to 0, will not use the random combiner (Default).
      You can set a positive integer to use the random combiner, e.g., 3.
+    is_pnnx:
+      True to make this class exportable via PNNX.
    """

    def __init__(
@@ -129,6 +131,7 @@ class RNN(EncoderInterface):
        dropout: float = 0.1,
        layer_dropout: float = 0.075,
        aux_layer_period: int = 0,
+        is_pnnx: bool = False,
    ) -> None:
        super(RNN, self).__init__()

@@ -142,7 +145,13 @@ class RNN(EncoderInterface):
        # That is, it does two things simultaneously:
        #   (1) subsampling: T -> T//subsampling_factor
        #   (2) embedding: num_features -> d_model
-        self.encoder_embed = Conv2dSubsampling(num_features, d_model)
+        self.encoder_embed = Conv2dSubsampling(
+            num_features,
+            d_model,
+            is_pnnx=is_pnnx,
+        )
+
+        self.is_pnnx = is_pnnx

        self.num_encoder_layers = num_encoder_layers
        self.d_model = d_model
@@ -209,7 +218,13 @@ class RNN(EncoderInterface):
        # lengths = ((x_lens - 3) // 2 - 1) // 2 # issue an warning
        #
        # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
-        lengths = (((x_lens - 3) >> 1) - 1) >> 1
+        if not self.is_pnnx:
+            lengths = (((x_lens - 3) >> 1) - 1) >> 1
+        else:
+            lengths1 = torch.floor((x_lens - 3) / 2)
+            lengths = torch.floor((lengths1 - 1) / 2)
+            lengths = lengths.to(x_lens)

        if not torch.jit.is_tracing():
            assert x.size(0) == lengths.max().item()

@@ -359,7 +374,7 @@ class RNNEncoderLayer(nn.Module):
        # for cell state
        assert states[1].shape == (1, src.size(1), self.rnn_hidden_size)
        src_lstm, new_states = self.lstm(src, states)
-        src = src + self.dropout(src_lstm)
+        src = self.dropout(src_lstm) + src

        # feed forward module
        src = src + self.dropout(self.feed_forward(src))
@@ -505,6 +520,7 @@ class Conv2dSubsampling(nn.Module):
        layer1_channels: int = 8,
        layer2_channels: int = 32,
        layer3_channels: int = 128,
+        is_pnnx: bool = False,
    ) -> None:
        """
        Args:
@@ -517,6 +533,9 @@ class Conv2dSubsampling(nn.Module):
            Number of channels in layer1
          layer1_channels:
            Number of channels in layer2
+          is_pnnx:
+            True if we are converting the model to PNNX format.
+            False otherwise.
        """
        assert in_channels >= 9
        super().__init__()
@@ -559,6 +578,10 @@ class Conv2dSubsampling(nn.Module):
            channel_dim=-1, min_positive=0.45, max_positive=0.55
        )
+
+        # ncnn supports only batch size == 1
+        self.is_pnnx = is_pnnx
+        self.conv_out_dim = self.out.weight.shape[1]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Subsample x.

@@ -572,9 +595,15 @@ class Conv2dSubsampling(nn.Module):
        # On entry, x is (N, T, idim)
        x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
        x = self.conv(x)
-        # Now x is of shape (N, odim, ((T-3)//2-1)//2, ((idim-3)//2-1)//2)
-        b, c, t, f = x.size()
-        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+
+        if torch.jit.is_tracing() and self.is_pnnx:
+            x = x.permute(0, 2, 1, 3).reshape(1, -1, self.conv_out_dim)
+            x = self.out(x)
+        else:
+            # Now x is of shape (N, odim, ((T-3)//2-1)//2, ((idim-3)//2-1)//2)
+            b, c, t, f = x.size()
+            x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))

        # Now x is of shape (N, ((T-3)//2-1))//2, odim)
        x = self.out_norm(x)
        x = self.out_balancer(x)
@@ -169,6 +169,18 @@ def get_parser():
        """,
    )

+    parser.add_argument(
+        "--pnnx",
+        type=str2bool,
+        default=False,
+        help="""True to save a model after applying torch.jit.trace for later
+        converting to PNNX. It will generate 3 files:
+          - encoder_jit_trace-pnnx.pt
+          - decoder_jit_trace-pnnx.pt
+          - joiner_jit_trace-pnnx.pt
+        """,
+    )
+
    parser.add_argument(
        "--context-size",
        type=int,
@@ -277,6 +289,10 @@ def main():

    logging.info(params)

+    if params.pnnx:
+        params.is_pnnx = params.pnnx
+        logging.info("For PNNX")
+
    logging.info("About to create model")
    model = get_transducer_model(params, enable_giga=False)

@@ -371,7 +387,18 @@ def main():
    model.to("cpu")
    model.eval()

-    if params.jit_trace is True:
+    if params.pnnx:
+        convert_scaled_to_non_scaled(model, inplace=True)
+        logging.info("Using torch.jit.trace()")
+        encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt"
+        export_encoder_model_jit_trace(model.encoder, encoder_filename)
+
+        decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt"
+        export_decoder_model_jit_trace(model.decoder, decoder_filename)
+
+        joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt"
+        export_joiner_model_jit_trace(model.joiner, joiner_filename)
+    elif params.jit_trace is True:
        convert_scaled_to_non_scaled(model, inplace=True)
        logging.info("Using torch.jit.trace()")
        encoder_filename = params.exp_dir / "encoder_jit_trace.pt"
egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py (executable file, 295 lines)
@ -0,0 +1,295 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# flake8: noqa
|
||||||
|
#
|
||||||
|
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
./lstm_transducer_stateless2/ncnn-decode.py \
|
||||||
|
--bpe-model-filename ./data/lang_bpe_500/bpe.model \
|
||||||
|
--encoder-param-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.param \
|
||||||
|
--encoder-bin-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.bin \
|
||||||
|
--decoder-param-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.param \
|
||||||
|
--decoder-bin-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.bin \
|
||||||
|
--joiner-param-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-iter-468000-avg-16-pnnx.ncnn.param \
|
||||||
|
--joiner-bin-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-iter-468000-avg-16-pnnx.ncnn.bin \
|
||||||
|
./test_wavs/1089-134686-0001.wav
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import kaldifeat
|
||||||
|
import ncnn
|
||||||
|
import sentencepiece as spm
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--bpe-model-filename",
|
||||||
|
type=str,
|
||||||
|
help="Path to bpe.model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--encoder-param-filename",
|
||||||
|
type=str,
|
||||||
|
help="Path to encoder.ncnn.param",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--encoder-bin-filename",
|
||||||
|
type=str,
|
||||||
|
help="Path to encoder.ncnn.bin",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--decoder-param-filename",
|
||||||
|
type=str,
|
||||||
|
help="Path to decoder.ncnn.param",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--decoder-bin-filename",
|
||||||
|
type=str,
|
||||||
|
help="Path to decoder.ncnn.bin",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--joiner-param-filename",
|
||||||
|
type=str,
|
||||||
|
help="Path to joiner.ncnn.param",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--joiner-bin-filename",
|
||||||
|
type=str,
|
||||||
|
help="Path to joiner.ncnn.bin",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"sound_filename",
|
||||||
|
type=str,
|
||||||
|
help="Path to foo.wav",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
class Model:
|
||||||
|
def __init__(self, args):
|
||||||
|
self.init_encoder(args)
|
||||||
|
self.init_decoder(args)
|
||||||
|
self.init_joiner(args)
|
||||||
|
|
||||||
|
def init_encoder(self, args):
|
||||||
|
encoder_net = ncnn.Net()
|
||||||
|
encoder_net.opt.use_packing_layout = False
|
||||||
|
encoder_net.opt.use_fp16_storage = False
|
||||||
|
encoder_param = args.encoder_param_filename
|
||||||
|
encoder_model = args.encoder_bin_filename
|
||||||
|
|
||||||
|
encoder_net.load_param(encoder_param)
|
||||||
|
encoder_net.load_model(encoder_model)
|
||||||
|
|
||||||
|
self.encoder_net = encoder_net
|
||||||
|
|
||||||
|
def init_decoder(self, args):
|
||||||
|
decoder_param = args.decoder_param_filename
|
||||||
|
decoder_model = args.decoder_bin_filename
|
||||||
|
|
||||||
|
decoder_net = ncnn.Net()
|
||||||
|
decoder_net.opt.use_packing_layout = False
|
||||||
|
|
||||||
|
decoder_net.load_param(decoder_param)
|
||||||
|
decoder_net.load_model(decoder_model)
|
||||||
|
|
||||||
|
self.decoder_net = decoder_net
|
||||||
|
|
||||||
|
def init_joiner(self, args):
|
||||||
|
joiner_param = args.joiner_param_filename
|
||||||
|
joiner_model = args.joiner_bin_filename
|
||||||
|
joiner_net = ncnn.Net()
|
||||||
|
joiner_net.opt.use_packing_layout = False
|
||||||
|
joiner_net.load_param(joiner_param)
|
||||||
|
joiner_net.load_model(joiner_model)
|
||||||
|
|
||||||
|
self.joiner_net = joiner_net
|
||||||
|
|
||||||
|
def run_encoder(self, x, states):
|
||||||
|
with self.encoder_net.create_extractor() as ex:
|
||||||
|
ex.set_num_threads(10)
|
||||||
|
ex.input("in0", ncnn.Mat(x.numpy()).clone())
|
||||||
|
x_lens = torch.tensor([x.size(0)], dtype=torch.float32)
|
||||||
|
ex.input("in1", ncnn.Mat(x_lens.numpy()).clone())
|
||||||
|
ex.input("in2", ncnn.Mat(states[0].numpy()).clone())
|
||||||
|
ex.input("in3", ncnn.Mat(states[1].numpy()).clone())
|
||||||
|
|
||||||
|
ret, ncnn_out0 = ex.extract("out0")
|
||||||
|
assert ret == 0, ret
|
||||||
|
|
||||||
|
ret, ncnn_out1 = ex.extract("out1")
|
||||||
|
assert ret == 0, ret
|
||||||
|
|
||||||
|
ret, ncnn_out2 = ex.extract("out2")
|
||||||
|
assert ret == 0, ret
|
||||||
|
|
||||||
|
ret, ncnn_out3 = ex.extract("out3")
|
||||||
|
assert ret == 0, ret
|
||||||
|
|
||||||
|
encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone()
|
||||||
|
encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to(
|
||||||
|
torch.int32
|
||||||
|
)
|
||||||
|
hx = torch.from_numpy(ncnn_out2.numpy()).clone()
|
||||||
|
cx = torch.from_numpy(ncnn_out3.numpy()).clone()
|
||||||
|
return encoder_out, encoder_out_lens, hx, cx
|
||||||
|
|
||||||
|
def run_decoder(self, decoder_input):
|
||||||
|
assert decoder_input.dtype == torch.int32
|
||||||
|
|
||||||
|
with self.decoder_net.create_extractor() as ex:
|
||||||
|
ex.set_num_threads(10)
|
||||||
|
ex.input("in0", ncnn.Mat(decoder_input.numpy()).clone())
|
||||||
|
ret, ncnn_out0 = ex.extract("out0")
|
||||||
|
assert ret == 0, ret
|
||||||
|
decoder_out = torch.from_numpy(ncnn_out0.numpy()).clone()
|
||||||
|
return decoder_out
|
||||||
|
|
||||||
|
def run_joiner(self, encoder_out, decoder_out):
|
||||||
|
with self.joiner_net.create_extractor() as ex:
|
||||||
|
ex.set_num_threads(10)
|
||||||
|
ex.input("in0", ncnn.Mat(encoder_out.numpy()).clone())
|
||||||
|
ex.input("in1", ncnn.Mat(decoder_out.numpy()).clone())
|
||||||
|
ret, ncnn_out0 = ex.extract("out0")
|
||||||
|
assert ret == 0, ret
|
||||||
|
joiner_out = torch.from_numpy(ncnn_out0.numpy()).clone()
|
||||||
|
return joiner_out
|
||||||
|
|
||||||
|
|
||||||
|
def read_sound_files(
|
||||||
|
filenames: List[str], expected_sample_rate: float
|
||||||
|
) -> List[torch.Tensor]:
|
||||||
|
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||||
|
Args:
|
||||||
|
filenames:
|
||||||
|
A list of sound filenames.
|
||||||
|
expected_sample_rate:
|
||||||
|
The expected sample rate of the sound files.
|
||||||
|
Returns:
|
||||||
|
Return a list of 1-D float32 torch tensors.
|
||||||
|
"""
|
||||||
|
ans = []
|
||||||
|
for f in filenames:
|
||||||
|
wave, sample_rate = torchaudio.load(f)
|
||||||
|
assert sample_rate == expected_sample_rate, (
|
||||||
|
f"expected sample rate: {expected_sample_rate}. "
|
||||||
|
f"Given: {sample_rate}"
|
||||||
|
)
|
||||||
|
# We use only the first channel
|
||||||
|
ans.append(wave[0])
|
||||||
|
return ans
|
||||||
|
|
||||||
|
|
||||||
|
def greedy_search(model: Model, encoder_out: torch.Tensor):
|
||||||
|
assert encoder_out.ndim == 2
|
||||||
|
T = encoder_out.size(0)
|
||||||
|
|
||||||
|
context_size = 2
|
||||||
|
blank_id = 0 # hard-code to 0
|
||||||
|
hyp = [blank_id] * context_size
|
||||||
|
|
||||||
|
decoder_input = torch.tensor(hyp, dtype=torch.int32) # (1, context_size)
|
||||||
|
|
||||||
|
decoder_out = model.run_decoder(decoder_input).squeeze(0)
|
||||||
|
# print(decoder_out.shape) # (512,)
|
||||||
|
|
||||||
|
for t in range(T):
|
||||||
|
encoder_out_t = encoder_out[t]
|
||||||
|
joiner_out = model.run_joiner(encoder_out_t, decoder_out)
|
||||||
|
# print(joiner_out.shape) # [500]
|
||||||
|
y = joiner_out.argmax(dim=0).tolist()
|
||||||
|
if y != blank_id:
|
||||||
|
hyp.append(y)
|
||||||
|
decoder_input = hyp[-context_size:]
|
||||||
|
decoder_input = torch.tensor(decoder_input, dtype=torch.int32)
|
||||||
|
decoder_out = model.run_decoder(decoder_input).squeeze(0)
|
||||||
|
return hyp[context_size:]
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = get_args()
|
||||||
|
logging.info(vars(args))
|
||||||
|
|
||||||
|
model = Model(args)
|
||||||
|
|
||||||
|
sp = spm.SentencePieceProcessor()
|
||||||
|
sp.load(args.bpe_model_filename)
|
||||||
|
|
||||||
|
sound_file = args.sound_filename
|
||||||
|
|
||||||
|
sample_rate = 16000
|
||||||
|
|
||||||
|
logging.info("Constructing Fbank computer")
|
||||||
|
opts = kaldifeat.FbankOptions()
|
||||||
|
opts.device = "cpu"
|
||||||
|
opts.frame_opts.dither = 0
|
||||||
|
opts.frame_opts.snip_edges = False
|
||||||
|
opts.frame_opts.samp_freq = sample_rate
|
||||||
|
opts.mel_opts.num_bins = 80
|
||||||
|
|
||||||
|
fbank = kaldifeat.Fbank(opts)
|
||||||
|
|
||||||
|
logging.info(f"Reading sound files: {sound_file}")
|
||||||
|
wave_samples = read_sound_files(
|
||||||
|
filenames=[sound_file],
|
||||||
|
expected_sample_rate=sample_rate,
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
logging.info("Decoding started")
|
||||||
|
features = fbank(wave_samples)
|
||||||
|
|
||||||
|
num_encoder_layers = 12
|
||||||
|
d_model = 512
|
||||||
|
rnn_hidden_size = 1024
|
||||||
|
|
||||||
|
    states = (
        torch.zeros(num_encoder_layers, d_model),
        torch.zeros(
            num_encoder_layers,
            rnn_hidden_size,
        ),
    )

    encoder_out, encoder_out_lens, hx, cx = model.run_encoder(features, states)
    hyp = greedy_search(model, encoder_out)
    logging.info(sound_file)
    logging.info(sp.decode(hyp))


if __name__ == "__main__":
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )

    logging.basicConfig(format=formatter, level=logging.INFO)

    main()
353
egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py
Executable file
@@ -0,0 +1,353 @@
#!/usr/bin/env python3
# flake8: noqa
#
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
from typing import List, Optional

import ncnn
import sentencepiece as spm
import torch
import torchaudio
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--bpe-model-filename",
        type=str,
        help="Path to bpe.model",
    )

    parser.add_argument(
        "--encoder-param-filename",
        type=str,
        help="Path to encoder.ncnn.param",
    )

    parser.add_argument(
        "--encoder-bin-filename",
        type=str,
        help="Path to encoder.ncnn.bin",
    )

    parser.add_argument(
        "--decoder-param-filename",
        type=str,
        help="Path to decoder.ncnn.param",
    )

    parser.add_argument(
        "--decoder-bin-filename",
        type=str,
        help="Path to decoder.ncnn.bin",
    )

    parser.add_argument(
        "--joiner-param-filename",
        type=str,
        help="Path to joiner.ncnn.param",
    )

    parser.add_argument(
        "--joiner-bin-filename",
        type=str,
        help="Path to joiner.ncnn.bin",
    )

    parser.add_argument(
        "sound_filename",
        type=str,
        help="Path to foo.wav",
    )

    return parser.parse_args()


class Model:
    def __init__(self, args):
        self.init_encoder(args)
        self.init_decoder(args)
        self.init_joiner(args)

    def init_encoder(self, args):
        encoder_net = ncnn.Net()
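        # Assumption: packed layouts and fp16 storage are disabled so the
        # runtime tensor layout matches what the pnnx-converted model was
        # exported with.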
        encoder_net.opt.use_packing_layout = False
        encoder_net.opt.use_fp16_storage = False
        encoder_param = args.encoder_param_filename
        encoder_model = args.encoder_bin_filename

        encoder_net.load_param(encoder_param)
        encoder_net.load_model(encoder_model)

        self.encoder_net = encoder_net

    def init_decoder(self, args):
        decoder_param = args.decoder_param_filename
        decoder_model = args.decoder_bin_filename

        decoder_net = ncnn.Net()
        decoder_net.opt.use_packing_layout = False

        decoder_net.load_param(decoder_param)
        decoder_net.load_model(decoder_model)

        self.decoder_net = decoder_net

    def init_joiner(self, args):
        joiner_param = args.joiner_param_filename
        joiner_model = args.joiner_bin_filename
        joiner_net = ncnn.Net()
        joiner_net.opt.use_packing_layout = False
        joiner_net.load_param(joiner_param)
        joiner_net.load_model(joiner_model)

        self.joiner_net = joiner_net

    def run_encoder(self, x, states):
        with self.encoder_net.create_extractor() as ex:
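            # pnnx names the traced graph's inputs/outputs in0..in3 and
            # out0..out3 in signature order, assumed here to be
            # (x, x_lens, h, c) -> (encoder_out, encoder_out_lens, h, c).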
            # ex.set_num_threads(10)
            ex.input("in0", ncnn.Mat(x.numpy()).clone())
            x_lens = torch.tensor([x.size(0)], dtype=torch.float32)
            ex.input("in1", ncnn.Mat(x_lens.numpy()).clone())
            ex.input("in2", ncnn.Mat(states[0].numpy()).clone())
            ex.input("in3", ncnn.Mat(states[1].numpy()).clone())

            ret, ncnn_out0 = ex.extract("out0")
            assert ret == 0, ret

            ret, ncnn_out1 = ex.extract("out1")
            assert ret == 0, ret

            ret, ncnn_out2 = ex.extract("out2")
            assert ret == 0, ret

            ret, ncnn_out3 = ex.extract("out3")
            assert ret == 0, ret

            encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone()
            encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to(
                torch.int32
            )
            hx = torch.from_numpy(ncnn_out2.numpy()).clone()
            cx = torch.from_numpy(ncnn_out3.numpy()).clone()
            return encoder_out, encoder_out_lens, hx, cx

    def run_decoder(self, decoder_input):
        assert decoder_input.dtype == torch.int32

        with self.decoder_net.create_extractor() as ex:
            # ex.set_num_threads(10)
            ex.input("in0", ncnn.Mat(decoder_input.numpy()).clone())
            ret, ncnn_out0 = ex.extract("out0")
            assert ret == 0, ret
            decoder_out = torch.from_numpy(ncnn_out0.numpy()).clone()
            return decoder_out

    def run_joiner(self, encoder_out, decoder_out):
        with self.joiner_net.create_extractor() as ex:
            # ex.set_num_threads(10)
            ex.input("in0", ncnn.Mat(encoder_out.numpy()).clone())
            ex.input("in1", ncnn.Mat(decoder_out.numpy()).clone())
            ret, ncnn_out0 = ex.extract("out0")
            assert ret == 0, ret
            joiner_out = torch.from_numpy(ncnn_out0.numpy()).clone()
            return joiner_out


def read_sound_files(
    filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
    """Read a list of sound files into a list of 1-D float32 torch tensors.
    Args:
      filenames:
        A list of sound filenames.
      expected_sample_rate:
        The expected sample rate of the sound files.
    Returns:
      Return a list of 1-D float32 torch tensors.
    """
    ans = []
    for f in filenames:
        wave, sample_rate = torchaudio.load(f)
        assert sample_rate == expected_sample_rate, (
            f"expected sample rate: {expected_sample_rate}. "
            f"Given: {sample_rate}"
        )
        # We use only the first channel
        ans.append(wave[0])
    return ans


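# The options below are fixed on purpose: they are assumed to match the
# fbank configuration used when the model was trained (16 kHz input,
# 80-dim mel bins, no dither, snip_edges=False); the exported encoder
# expects features produced this way.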
def create_streaming_feature_extractor() -> OnlineFeature:
    """Create a CPU streaming feature extractor.

    At present, we assume it returns a fbank feature extractor with
    fixed options. In the future, we will support passing in the options
    from outside.

    Returns:
      Return a CPU streaming feature extractor.
    """
    opts = FbankOptions()
    opts.device = "cpu"
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = 16000
    opts.mel_opts.num_bins = 80
    return OnlineFbank(opts)


def greedy_search(
    model: Model,
    encoder_out: torch.Tensor,
    decoder_out: Optional[torch.Tensor] = None,
    hyp: Optional[List[int]] = None,
):
    assert encoder_out.ndim == 1
    context_size = 2
    blank_id = 0

    if decoder_out is None:
        assert hyp is None, hyp
        hyp = [blank_id] * context_size
        decoder_input = torch.tensor(
            hyp, dtype=torch.int32
        )  # (context_size,)
        decoder_out = model.run_decoder(decoder_input).squeeze(0)
    else:
        assert decoder_out.ndim == 1
        assert hyp is not None, hyp

    joiner_out = model.run_joiner(encoder_out, decoder_out)
    y = joiner_out.argmax(dim=0).tolist()
    if y != blank_id:
        hyp.append(y)
        decoder_input = hyp[-context_size:]
        decoder_input = torch.tensor(decoder_input, dtype=torch.int32)
        decoder_out = model.run_decoder(decoder_input).squeeze(0)

    return hyp, decoder_out
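

# This streaming variant consumes one encoder output frame per call and
# threads (hyp, decoder_out) through successive calls, so the decoding
# state survives across audio chunks.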
def main():
    args = get_args()
    logging.info(vars(args))

    model = Model(args)

    sp = spm.SentencePieceProcessor()
    sp.load(args.bpe_model_filename)

    sound_file = args.sound_filename

    sample_rate = 16000

    logging.info("Constructing Fbank computer")
    online_fbank = create_streaming_feature_extractor()

    logging.info(f"Reading sound files: {sound_file}")
    wave_samples = read_sound_files(
        filenames=[sound_file],
        expected_sample_rate=sample_rate,
    )[0]
    logging.info(wave_samples.shape)

    num_encoder_layers = 12
    batch_size = 1
    d_model = 512
    rnn_hidden_size = 1024

    states = (
        torch.zeros(num_encoder_layers, batch_size, d_model),
        torch.zeros(
            num_encoder_layers,
            batch_size,
            rnn_hidden_size,
        ),
    )

    hyp = None
    decoder_out = None

    num_processed_frames = 0
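    # segment is the number of fbank frames fed to the encoder per chunk;
    # offset is how far the window advances each time, assumed to match
    # the encoder's subsampling factor of 4, so consecutive chunks share
    # segment - offset = 5 frames of context.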
    segment = 9
    offset = 4

    chunk = 3200  # 0.2 second

    start = 0
    while start < wave_samples.numel():
        end = min(start + chunk, wave_samples.numel())
        samples = wave_samples[start:end]
        start += chunk

        online_fbank.accept_waveform(
            sampling_rate=sample_rate,
            waveform=samples,
        )
        while online_fbank.num_frames_ready - num_processed_frames >= segment:
            frames = []
            for i in range(segment):
                frames.append(online_fbank.get_frame(num_processed_frames + i))
            num_processed_frames += offset
            frames = torch.cat(frames, dim=0)
            encoder_out, encoder_out_lens, hx, cx = model.run_encoder(
                frames, states
            )
            states = (hx, cx)
            hyp, decoder_out = greedy_search(
                model, encoder_out.squeeze(0), decoder_out, hyp
            )
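
    # Flush the feature extractor with 0.5 s of trailing silence
    # (8000 samples at 16 kHz) so the last frames of real audio become
    # ready for decoding.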
    online_fbank.accept_waveform(
        sampling_rate=sample_rate, waveform=torch.zeros(8000, dtype=torch.float32)
    )

    online_fbank.input_finished()
    while online_fbank.num_frames_ready - num_processed_frames >= segment:
        frames = []
        for i in range(segment):
            frames.append(online_fbank.get_frame(num_processed_frames + i))
        num_processed_frames += offset
        frames = torch.cat(frames, dim=0)
        encoder_out, encoder_out_lens, hx, cx = model.run_encoder(
            frames, states
        )
        states = (hx, cx)
        hyp, decoder_out = greedy_search(
            model, encoder_out.squeeze(0), decoder_out, hyp
        )

    context_size = 2

    logging.info(sound_file)
    logging.info(sp.decode(hyp[context_size:]))


if __name__ == "__main__":
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )

    logging.basicConfig(format=formatter, level=logging.INFO)

    main()

@@ -406,6 +406,8 @@ def get_params() -> AttributeDict:
             "decoder_dim": 512,
             # parameters for joiner
             "joiner_dim": 512,
+            # True to generate a model that can be exported via PNNX
+            "is_pnnx": False,
             # parameters for Noam
             "model_warm_step": 3000,  # arg given to model, not for lrate
             "env_info": get_env_info(),
@@ -424,6 +426,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
         aux_layer_period=params.aux_layer_period,
+        is_pnnx=params.is_pnnx,
     )
     return encoder
@@ -30,6 +30,7 @@ from typing import List
 import torch
 import torch.nn as nn
 from scaling import (
+    BasicNorm,
     ScaledConv1d,
     ScaledConv2d,
     ScaledEmbedding,
@@ -38,6 +39,29 @@ from scaling import (
 )


+class NonScaledNorm(nn.Module):
+    """See BasicNorm for doc"""
+
+    def __init__(
+        self,
+        num_channels: int,
+        eps_exp: float,
+        channel_dim: int = -1,  # CAUTION: see documentation.
+    ):
+        super().__init__()
+        self.num_channels = num_channels
+        self.channel_dim = channel_dim
+        self.eps_exp = eps_exp
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if not torch.jit.is_tracing():
+            assert x.shape[self.channel_dim] == self.num_channels
+        scales = (
+            torch.mean(x * x, dim=self.channel_dim, keepdim=True) + self.eps_exp
+        ).pow(-0.5)
+        return x * scales
+
+
 def scaled_linear_to_linear(scaled_linear: ScaledLinear) -> nn.Linear:
     """Convert an instance of ScaledLinear to nn.Linear.
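
A minimal sketch (assuming NonScaledNorm from the hunk above is in scope) of what the module computes: a fixed-epsilon RMS normalization, which is traceable by pnnx because eps_exp is a plain constant rather than a learned log-space parameter:

    import torch

    norm = NonScaledNorm(num_channels=4, eps_exp=0.25)
    x = torch.randn(10, 4)
    y = norm(x)
    # closed form: x * (E[x^2] + eps_exp) ** (-0.5)
    expected = x * (x.pow(2).mean(dim=-1, keepdim=True) + 0.25).pow(-0.5)
    assert torch.allclose(y, expected)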

@@ -174,6 +198,16 @@ def scaled_embedding_to_embedding(
     return embedding


+def convert_basic_norm(basic_norm: BasicNorm) -> NonScaledNorm:
+    assert isinstance(basic_norm, BasicNorm), type(basic_norm)
+    norm = NonScaledNorm(
+        num_channels=basic_norm.num_channels,
+        eps_exp=basic_norm.eps.data.exp().item(),
+        channel_dim=basic_norm.channel_dim,
+    )
+    return norm
+
+
 def scaled_lstm_to_lstm(scaled_lstm: ScaledLSTM) -> nn.LSTM:
     """Convert an instance of ScaledLSTM to nn.LSTM.
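
Note: BasicNorm keeps its epsilon as a learnable parameter in log space, which is why eps.data.exp() above recovers the actual epsilon that gets frozen into NonScaledNorm.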

@@ -256,6 +290,8 @@ def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False):
             d[name] = scaled_conv2d_to_conv2d(m)
         elif isinstance(m, ScaledEmbedding):
             d[name] = scaled_embedding_to_embedding(m)
+        elif isinstance(m, BasicNorm):
+            d[name] = convert_basic_norm(m)
         elif isinstance(m, ScaledLSTM):
             d[name] = scaled_lstm_to_lstm(m)
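
Hypothetical usage of the converter during export, before tracing the model for pnnx:

    convert_scaled_to_non_scaled(model, inplace=True)

After this call every ScaledLinear, ScaledConv1d/2d, ScaledEmbedding, ScaledLSTM, and BasicNorm submodule has been replaced by a plain PyTorch equivalent that ncnn can run.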