support exporting to ncnn format via PNNX (#571)

This commit is contained in:
Fangjun Kuang 2022-09-20 22:52:49 +08:00 committed by GitHub
parent 436942211c
commit 099cd3a215
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 1674 additions and 7 deletions

View File

@ -0,0 +1,160 @@
#!/usr/bin/env bash
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
cd egs/librispeech/ASR
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
log "Downloading pre-trained model from $repo_url"
git lfs install
git clone $repo_url
repo=$(basename $repo_url)
log "Display test files"
tree $repo/
soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav
pushd $repo/exp
ln -s pretrained-iter-468000-avg-16.pt pretrained.pt
ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt
popd
log "Install ncnn and pnnx"
# We are using a modified ncnn here. Will try to merge it to the official repo
# of ncnn
git clone https://github.com/csukuangfj/ncnn
pushd ncnn
git submodule init
git submodule update python/pybind11
python3 setup.py bdist_wheel
ls -lh dist/
pip install dist/*.whl
cd tools/pnnx
mkdir build
cd build
cmake ..
make -j4 pnnx
./src/pnnx || echo "pass"
popd
log "Test exporting to pnnx format"
./lstm_transducer_stateless2/export.py \
--exp-dir $repo/exp \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 99 \
--avg 1 \
--use-averaged-model 0 \
--pnnx 1
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/encoder_jit_trace-pnnx.pt
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/decoder_jit_trace-pnnx.pt
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/joiner_jit_trace-pnnx.pt
./lstm_transducer_stateless2/ncnn-decode.py \
--bpe-model-filename $repo/data/lang_bpe_500/bpe.model \
--encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \
--encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \
--decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \
--decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \
--joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \
--joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \
$repo/test_wavs/1089-134686-0001.wav
./lstm_transducer_stateless2/streaming-ncnn-decode.py \
--bpe-model-filename $repo/data/lang_bpe_500/bpe.model \
--encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \
--encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \
--decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \
--decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \
--joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \
--joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \
$repo/test_wavs/1089-134686-0001.wav
log "Test exporting with torch.jit.trace()"
./lstm_transducer_stateless2/export.py \
--exp-dir $repo/exp \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 99 \
--avg 1 \
--use-averaged-model 0 \
--jit-trace 1
log "Decode with models exported by torch.jit.trace()"
./lstm_transducer_stateless2/jit_pretrained.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--encoder-model-filename $repo/exp/encoder_jit_trace.pt \
--decoder-model-filename $repo/exp/decoder_jit_trace.pt \
--joiner-model-filename $repo/exp/joiner_jit_trace.pt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
for sym in 1 2 3; do
log "Greedy search with --max-sym-per-frame $sym"
./lstm_transducer_stateless2/pretrained.py \
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
done
for method in modified_beam_search beam_search fast_beam_search; do
log "$method"
./lstm_transducer_stateless2/pretrained.py \
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
done
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"ncnn" ]]; then
mkdir -p lstm_transducer_stateless2/exp
ln -s $PWD/$repo/exp/pretrained.pt lstm_transducer_stateless2/exp/epoch-999.pt
ln -s $PWD/$repo/data/lang_bpe_500 data/
ls -lh data
ls -lh lstm_transducer_stateless2/exp
log "Decoding test-clean and test-other"
# use a small value for decoding with CPU
max_duration=100
for method in greedy_search fast_beam_search modified_beam_search; do
log "Decoding with $method"
./lstm_transducer_stateless2/decode.py \
--decoding-method $method \
--epoch 999 \
--avg 1 \
--use-averaged-model 0 \
--max-duration $max_duration \
--exp-dir lstm_transducer_stateless2/exp
done
rm lstm_transducer_stateless2/exp/*.pt
fi

View File

@ -0,0 +1,136 @@
name: run-librispeech-lstm-transducer-2022-09-03
on:
push:
branches:
- master
pull_request:
types: [labeled]
schedule:
# minute (0-59)
# hour (0-23)
# day of the month (1-31)
# month (1-12)
# day of the week (0-6)
# nightly build at 15:50 UTC time every day
- cron: "50 15 * * *"
jobs:
run_librispeech_pruned_transducer_stateless3_2022_05_13:
if: github.event.label.name == 'ncnn' || github.event_name == 'push' || github.event_name == 'schedule'
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
python-version: [3.8]
fail-fast: false
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'
cache-dependency-path: '**/requirements-ci.txt'
- name: Install Python dependencies
run: |
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
pip uninstall -y protobuf
pip install --no-binary protobuf protobuf
- name: Cache kaldifeat
id: my-cache
uses: actions/cache@v2
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/install-kaldifeat.sh
- name: Cache LibriSpeech test-clean and test-other datasets
id: libri-test-clean-and-test-other-data
uses: actions/cache@v2
with:
path: |
~/tmp/download
key: cache-libri-test-clean-and-test-other
- name: Download LibriSpeech test-clean and test-other
if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
- name: Prepare manifests for LibriSpeech test-clean and test-other
shell: bash
run: |
.github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
- name: Cache LibriSpeech test-clean and test-other fbank features
id: libri-test-clean-and-test-other-fbank
uses: actions/cache@v2
with:
path: |
~/tmp/fbank-libri
key: cache-libri-fbank-test-clean-and-test-other-v2
- name: Compute fbank for LibriSpeech test-clean and test-other
if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
- name: Inference with pre-trained model
shell: bash
env:
GITHUB_EVENT_NAME: ${{ github.event_name }}
GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
run: |
mkdir -p egs/librispeech/ASR/data
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
ls -lh egs/librispeech/ASR/data/*
sudo apt-get -qq install git-lfs tree sox
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml
- name: Display decoding results for lstm_transducer_stateless2
if: github.event_name == 'schedule' || github.event.label.name == 'ncnn'
shell: bash
run: |
cd egs/librispeech/ASR
tree lstm_transducer_stateless2/exp
cd lstm_transducer_stateless2/exp
echo "===greedy search==="
find greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
echo "===fast_beam_search==="
find fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
echo "===modified beam search==="
find modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
- name: Upload decoding results for lstm_transducer_stateless2
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'ncnn'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-lstm_transducer_stateless2-2022-09-03
path: egs/librispeech/ASR/lstm_transducer_stateless2/exp/

2
.gitignore vendored
View File

@ -11,3 +11,5 @@ log
*.bak
*-bak
*bak.py
*.param
*.bin

Binary file not shown.

After

Width:  |  Height:  |  Size: 413 KiB

View File

@ -6,3 +6,4 @@ LibriSpeech
tdnn_lstm_ctc
conformer_ctc
lstm_pruned_stateless_transducer

View File

@ -0,0 +1,625 @@
Transducer
==========
.. hint::
Please scroll down to the bottom of this page to find download links
for pretrained models if you don't want to train a model from scratch.
This tutorial shows you how to train a transducer model
with the `LibriSpeech <https://www.openslr.org/12>`_ dataset.
We use pruned RNN-T to compute the loss.
.. note::
You can find the paper about pruned RNN-T at the following address:
`<https://arxiv.org/abs/2206.13236>`_
The transducer model consists of 3 parts:
- Encoder, a.k.a, transcriber. We use an LSTM model
- Decoder, a.k.a, predictor. We use a model consisting of ``nn.Embedding``
and ``nn.Conv1d``
- Joiner, a.k.a, the joint network.
.. caution::
Contrary to the conventional RNN-T models, we use a stateless decoder.
That is, it has no recurrent connections.
.. hint::
Since the encoder model is an LSTM, not Transformer/Conformer, the
resulting model is suitable for streaming/online ASR.
Which model to use
------------------
Currently, there are two folders about LSTM stateless transducer training:
- ``(1)`` `<https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/lstm_transducer_stateless>`_
This recipe uses only LibriSpeech during training.
- ``(2)`` `<https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/lstm_transducer_stateless2>`_
This recipe uses GigaSpeech + LibriSpeech during training.
``(1)`` and ``(2)`` use the same model architecture. The only difference is that ``(2)`` supports
multi-dataset. Since ``(2)`` uses more data, it has a lower WER than ``(1)`` but it needs
more training time.
We use ``lstm_transducer_stateless2`` as an example below.
.. note::
You need to download the `GigaSpeech <https://github.com/SpeechColab/GigaSpeech>`_ dataset
to run ``(2)``. If you have only ``LibriSpeech`` dataset available, feel free to use ``(1)``.
Data preparation
----------------
.. code-block:: bash
$ cd egs/librispeech/ASR
$ ./prepare.sh
# If you use (1), you can **skip** the following command
$ ./prepare_giga_speech.sh
The script ``./prepare.sh`` handles the data preparation for you, **automagically**.
All you need to do is to run it.
The data preparation contains several stages, you can use the following two
options:
- ``--stage``
- ``--stop-stage``
to control which stage(s) should be run. By default, all stages are executed.
For example,
.. code-block:: bash
$ cd egs/librispeech/ASR
$ ./prepare.sh --stage 0 --stop-stage 0
means to run only stage 0.
To run stage 2 to stage 5, use:
.. code-block:: bash
$ ./prepare.sh --stage 2 --stop-stage 5
.. hint::
If you have pre-downloaded the `LibriSpeech <https://www.openslr.org/12>`_
dataset and the `musan <http://www.openslr.org/17/>`_ dataset, say,
they are saved in ``/tmp/LibriSpeech`` and ``/tmp/musan``, you can modify
the ``dl_dir`` variable in ``./prepare.sh`` to point to ``/tmp`` so that
``./prepare.sh`` won't re-download them.
.. note::
All generated files by ``./prepare.sh``, e.g., features, lexicon, etc,
are saved in ``./data`` directory.
We provide the following YouTube video showing how to run ``./prepare.sh``.
.. note::
To get the latest news of `next-gen Kaldi <https://github.com/k2-fsa>`_, please subscribe
the following YouTube channel by `Nadira Povey <https://www.youtube.com/channel/UC_VaumpkmINz1pNkFXAN9mw>`_:
`<https://www.youtube.com/channel/UC_VaumpkmINz1pNkFXAN9mw>`_
.. youtube:: ofEIoJL-mGM
Training
--------
Configurable options
~~~~~~~~~~~~~~~~~~~~
.. code-block:: bash
$ cd egs/librispeech/ASR
$ ./lstm_transducer_stateless2/train.py --help
shows you the training options that can be passed from the commandline.
The following options are used quite often:
- ``--full-libri``
If it's True, the training part uses all the training data, i.e.,
960 hours. Otherwise, the training part uses only the subset
``train-clean-100``, which has 100 hours of training data.
.. CAUTION::
The training set is perturbed by speed with two factors: 0.9 and 1.1.
If ``--full-libri`` is True, each epoch actually processes
``3x960 == 2880`` hours of data.
- ``--num-epochs``
It is the number of epochs to train. For instance,
``./lstm_transducer_stateless2/train.py --num-epochs 30`` trains for 30 epochs
and generates ``epoch-1.pt``, ``epoch-2.pt``, ..., ``epoch-30.pt``
in the folder ``./lstm_transducer_stateless2/exp``.
- ``--start-epoch``
It's used to resume training.
``./lstm_transducer_stateless2/train.py --start-epoch 10`` loads the
checkpoint ``./lstm_transducer_stateless2/exp/epoch-9.pt`` and starts
training from epoch 10, based on the state from epoch 9.
- ``--world-size``
It is used for multi-GPU single-machine DDP training.
- (a) If it is 1, then no DDP training is used.
- (b) If it is 2, then GPU 0 and GPU 1 are used for DDP training.
The following shows some use cases with it.
**Use case 1**: You have 4 GPUs, but you only want to use GPU 0 and
GPU 2 for training. You can do the following:
.. code-block:: bash
$ cd egs/librispeech/ASR
$ export CUDA_VISIBLE_DEVICES="0,2"
$ ./lstm_transducer_stateless2/train.py --world-size 2
**Use case 2**: You have 4 GPUs and you want to use all of them
for training. You can do the following:
.. code-block:: bash
$ cd egs/librispeech/ASR
$ ./lstm_transducer_stateless2/train.py --world-size 4
**Use case 3**: You have 4 GPUs but you only want to use GPU 3
for training. You can do the following:
.. code-block:: bash
$ cd egs/librispeech/ASR
$ export CUDA_VISIBLE_DEVICES="3"
$ ./lstm_transducer_stateless2/train.py --world-size 1
.. caution::
Only multi-GPU single-machine DDP training is implemented at present.
Multi-GPU multi-machine DDP training will be added later.
- ``--max-duration``
It specifies the number of seconds over all utterances in a
batch, before **padding**.
If you encounter CUDA OOM, please reduce it.
.. HINT::
Due to padding, the number of seconds of all utterances in a
batch will usually be larger than ``--max-duration``.
A larger value for ``--max-duration`` may cause OOM during training,
while a smaller value may increase the training time. You have to
tune it.
- ``--giga-prob``
The probability to select a batch from the ``GigaSpeech`` dataset.
Note: It is available only for ``(2)``.
Pre-configured options
~~~~~~~~~~~~~~~~~~~~~~
There are some training options, e.g., weight decay,
number of warmup steps, results dir, etc,
that are not passed from the commandline.
They are pre-configured by the function ``get_params()`` in
`lstm_transducer_stateless2/train.py <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/lstm_transducer_stateless2/train.py>`_
You don't need to change these pre-configured parameters. If you really need to change
them, please modify ``./lstm_transducer_stateless2/train.py`` directly.
Training logs
~~~~~~~~~~~~~
Training logs and checkpoints are saved in ``lstm_transducer_stateless2/exp``.
You will find the following files in that directory:
- ``epoch-1.pt``, ``epoch-2.pt``, ...
These are checkpoint files saved at the end of each epoch, containing model
``state_dict`` and optimizer ``state_dict``.
To resume training from some checkpoint, say ``epoch-10.pt``, you can use:
.. code-block:: bash
$ ./lstm_transducer_stateless2/train.py --start-epoch 11
- ``checkpoint-436000.pt``, ``checkpoint-438000.pt``, ...
These are checkpoint files saved every ``--save-every-n`` batches,
containing model ``state_dict`` and optimizer ``state_dict``.
To resume training from some checkpoint, say ``checkpoint-436000``, you can use:
.. code-block:: bash
$ ./lstm_transducer_stateless2/train.py --start-batch 436000
- ``tensorboard/``
This folder contains TensorBoard logs. Training loss, validation loss, learning
rate, etc, are recorded in these logs. You can visualize them by:
.. code-block:: bash
$ cd lstm_transducer_stateless2/exp/tensorboard
$ tensorboard dev upload --logdir . --description "LSTM transducer training for LibriSpeech with icefall"
It will print something like below:
.. code-block::
TensorFlow installation not found - running with reduced feature set.
Upload started and will continue reading any new data as it's added to the logdir.
To stop uploading, press Ctrl-C.
New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/cj2vtPiwQHKN9Q1tx6PTpg/
[2022-09-20T15:50:50] Started scanning logdir.
Uploading 4468 scalars...
[2022-09-20T15:53:02] Total uploaded: 210171 scalars, 0 tensors, 0 binary objects
Listening for new data in logdir...
Note there is a URL in the above output, click it and you will see
the following screenshot:
.. figure:: images/librispeech-lstm-transducer-tensorboard-log.png
:width: 600
:alt: TensorBoard screenshot
:align: center
:target: https://tensorboard.dev/experiment/lzGnETjwRxC3yghNMd4kPw/
TensorBoard screenshot.
.. hint::
If you don't have access to google, you can use the following command
to view the tensorboard log locally:
.. code-block:: bash
cd lstm_transducer_stateless2/exp/tensorboard
tensorboard --logdir . --port 6008
It will print the following message:
.. code-block::
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.8.0 at http://localhost:6008/ (Press CTRL+C to quit)
Now start your browser and go to `<http://localhost:6008>`_ to view the tensorboard
logs.
- ``log/log-train-xxxx``
It is the detailed training log in text format, same as the one
you saw printed to the console during training.
Usage example
~~~~~~~~~~~~~
You can use the following command to start the training using 8 GPUs:
.. code-block:: bash
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
./lstm_transducer_stateless2/train.py \
--world-size 8 \
--num-epochs 35 \
--start-epoch 1 \
--full-libri 1 \
--exp-dir lstm_transducer_stateless2/exp \
--max-duration 500 \
--use-fp16 0 \
--lr-epochs 10 \
--num-workers 2 \
--giga-prob 0.9
Decoding
--------
The decoding part uses checkpoints saved by the training part, so you have
to run the training part first.
.. hint::
There are two kinds of checkpoints:
- (1) ``epoch-1.pt``, ``epoch-2.pt``, ..., which are saved at the end
of each epoch. You can pass ``--epoch`` to
``lstm_transducer_stateless2/decode.py`` to use them.
- (2) ``checkpoints-436000.pt``, ``epoch-438000.pt``, ..., which are saved
every ``--save-every-n`` batches. You can pass ``--iter`` to
``lstm_transducer_stateless2/decode.py`` to use them.
We suggest that you try both types of checkpoints and choose the one
that produces the lowest WERs.
.. code-block:: bash
$ cd egs/librispeech/ASR
$ ./lstm_transducer_stateless2/decode.py --help
shows the options for decoding.
The following shows two examples:
.. code-block:: bash
for m in greedy_search fast_beam_search modified_beam_search; do
for epoch in 17; do
for avg in 1 2; do
./lstm_transducer_stateless2/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir lstm_transducer_stateless2/exp \
--max-duration 600 \
--num-encoder-layers 12 \
--rnn-hidden-size 1024 \
--decoding-method $m \
--use-averaged-model True \
--beam 4 \
--max-contexts 4 \
--max-states 8 \
--beam-size 4
done
done
done
.. code-block:: bash
for m in greedy_search fast_beam_search modified_beam_search; do
for iter in 474000; do
for avg in 8 10 12 14 16 18; do
./lstm_transducer_stateless2/decode.py \
--iter $iter \
--avg $avg \
--exp-dir lstm_transducer_stateless2/exp \
--max-duration 600 \
--num-encoder-layers 12 \
--rnn-hidden-size 1024 \
--decoding-method $m \
--use-averaged-model True \
--beam 4 \
--max-contexts 4 \
--max-states 8 \
--beam-size 4
done
done
done
Export models
-------------
`lstm_transducer_stateless2/export.py <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/lstm_transducer_stateless2/export.py>`_ supports to export checkpoints from ``lstm_transducer_stateless2/exp`` in the following ways.
Export ``model.state_dict()``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Checkpoints saved by ``lstm_transducer_stateless2/train.py`` also include
``optimizer.state_dict()``. It is useful for resuming training. But after training,
we are interested only in ``model.state_dict()``. You can use the following
command to extract ``model.state_dict()``.
.. code-block:: bash
# Assume that --iter 468000 --avg 16 produces the smallest WER
# (You can get such information after running ./lstm_transducer_stateless2/decode.py)
iter=468000
avg=16
./lstm_transducer_stateless2/export.py \
--exp-dir ./lstm_transducer_stateless2/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--iter $iter \
--avg $avg
It will generate a file ``./lstm_transducer_stateless2/exp/pretrained.pt``.
.. hint::
To use the generated ``pretrained.pt`` for ``lstm_transducer_stateless2/decode.py``,
you can run:
.. code-block:: bash
cd lstm_transducer_stateless2/exp
ln -s pretrained epoch-9999.pt
And then pass `--epoch 9999 --avg 1 --use-averaged-model 0` to
``./lstm_transducer_stateless2/decode.py``.
To use the exported model with ``./lstm_transducer_stateless2/pretrained.py``, you
can run:
.. code-block:: bash
./lstm_transducer_stateless2/pretrained.py \
--checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
Export model using ``torch.jit.trace()``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code-block:: bash
iter=468000
avg=16
./lstm_transducer_stateless2/export.py \
--exp-dir ./lstm_transducer_stateless2/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--iter $iter \
--avg $avg \
--jit-trace 1
It will generate 3 files:
- ``./lstm_transducer_stateless2/exp/encoder_jit_trace.pt``
- ``./lstm_transducer_stateless2/exp/decoder_jit_trace.pt``
- ``./lstm_transducer_stateless2/exp/joiner_jit_trace.pt``
To use the generated files with ``./lstm_transducer_stateless2/jit_pretrained``:
.. code-block:: bash
./lstm_transducer_stateless2/jit_pretrained.py \
--bpe-model ./data/lang_bpe_500/bpe.model \
--encoder-model-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace.pt \
--decoder-model-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace.pt \
--joiner-model-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace.pt \
/path/to/foo.wav \
/path/to/bar.wav
Export model for ncnn
~~~~~~~~~~~~~~~~~~~~~
We support exporting pretrained LSTM transducer models to
`ncnn <https://github.com/tencent/ncnn>`_ using
`pnnx <https://github.com/Tencent/ncnn/tree/master/tools/pnnx>`_.
First, let us install a modified version of ``ncnn``:
.. code-block:: bash
git clone https://github.com/csukuangfj/ncnn
cd ncnn
git submodule update --recursive --init
python3 setup.py bdist_wheel
ls -lh dist/
pip install ./dist/*.whl
# now build pnnx
cd tools/pnnx
mkdir build
cd build
make -j4
export PATH=$PWD/src:$PATH
./src/pnnx
.. note::
We assume that you have added the path to the binary ``pnnx`` to the
environment variable ``PATH``.
Second, let us export the model using ``torch.jit.trace()`` that is suitable
for ``pnnx``:
.. code-block:: bash
iter=468000
avg=16
./lstm_transducer_stateless2/export.py \
--exp-dir ./lstm_transducer_stateless2/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--iter $iter \
--avg $avg \
--pnnx 1
It will generate 3 files:
- ``./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.pt``
- ``./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.pt``
- ``./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.pt``
Third, convert torchscript model to ``ncnn`` format:
.. code-block::
pnnx ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.pt
pnnx ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.pt
pnnx ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.pt
It will generate the following files:
- ``./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.param``
- ``./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.bin``
- ``./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.param``
- ``./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.bin``
- ``./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.param``
- ``./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.bin``
To use the above generate files, run:
.. code-block:: bash
./lstm_transducer_stateless2/ncnn-decode.py \
--bpe-model-filename ./data/lang_bpe_500/bpe.model \
--encoder-param-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.param \
--encoder-bin-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.bin \
--decoder-param-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.param \
--decoder-bin-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.bin \
--joiner-param-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.param \
--joiner-bin-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.bin \
/path/to/foo.wav
.. code-block:: bash
./lstm_transducer_stateless2/streaming-ncnn-decode.py \
--bpe-model-filename ./data/lang_bpe_500/bpe.model \
--encoder-param-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.param \
--encoder-bin-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.bin \
--decoder-param-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.param \
--decoder-bin-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.bin \
--joiner-param-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.param \
--joiner-bin-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.bin \
/path/to/foo.wav
To use the above generated files in C++, please see
`<https://github.com/k2-fsa/sherpa-ncnn>`_
It is able to generate a static linked library that can be run on Linux, Windows,
macOS, Raspberry Pi, etc.
Download pretrained models
--------------------------
If you don't want to train from scratch, you can download the pretrained models
by visiting the following links:
- `<https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03>`_
- `<https://huggingface.co/Zengwei/icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18>`_
See `<https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS.md>`_
for the details of the above pretrained models
You can find more usages of the pretrained models in
`<https://k2-fsa.github.io/sherpa/python/streaming_asr/lstm/index.html>`_

View File

@ -116,6 +116,8 @@ class RNN(EncoderInterface):
Period of auxiliary layers used for random combiner during training.
If set to 0, will not use the random combiner (Default).
You can set a positive integer to use the random combiner, e.g., 3.
is_pnnx:
True to make this class exportable via PNNX.
"""
def __init__(
@ -129,6 +131,7 @@ class RNN(EncoderInterface):
dropout: float = 0.1,
layer_dropout: float = 0.075,
aux_layer_period: int = 0,
is_pnnx: bool = False,
) -> None:
super(RNN, self).__init__()
@ -142,7 +145,13 @@ class RNN(EncoderInterface):
# That is, it does two things simultaneously:
# (1) subsampling: T -> T//subsampling_factor
# (2) embedding: num_features -> d_model
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
self.encoder_embed = Conv2dSubsampling(
num_features,
d_model,
is_pnnx=is_pnnx,
)
self.is_pnnx = is_pnnx
self.num_encoder_layers = num_encoder_layers
self.d_model = d_model
@ -209,7 +218,13 @@ class RNN(EncoderInterface):
# lengths = ((x_lens - 3) // 2 - 1) // 2 # issue an warning
#
# Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
lengths = (((x_lens - 3) >> 1) - 1) >> 1
if not self.is_pnnx:
lengths = (((x_lens - 3) >> 1) - 1) >> 1
else:
lengths1 = torch.floor((x_lens - 3) / 2)
lengths = torch.floor((lengths1 - 1) / 2)
lengths = lengths.to(x_lens)
if not torch.jit.is_tracing():
assert x.size(0) == lengths.max().item()
@ -359,7 +374,7 @@ class RNNEncoderLayer(nn.Module):
# for cell state
assert states[1].shape == (1, src.size(1), self.rnn_hidden_size)
src_lstm, new_states = self.lstm(src, states)
src = src + self.dropout(src_lstm)
src = self.dropout(src_lstm) + src
# feed forward module
src = src + self.dropout(self.feed_forward(src))
@ -505,6 +520,7 @@ class Conv2dSubsampling(nn.Module):
layer1_channels: int = 8,
layer2_channels: int = 32,
layer3_channels: int = 128,
is_pnnx: bool = False,
) -> None:
"""
Args:
@ -517,6 +533,9 @@ class Conv2dSubsampling(nn.Module):
Number of channels in layer1
layer1_channels:
Number of channels in layer2
is_pnnx:
True if we are converting the model to PNNX format.
False otherwise.
"""
assert in_channels >= 9
super().__init__()
@ -559,6 +578,10 @@ class Conv2dSubsampling(nn.Module):
channel_dim=-1, min_positive=0.45, max_positive=0.55
)
# ncnn supports only batch size == 1
self.is_pnnx = is_pnnx
self.conv_out_dim = self.out.weight.shape[1]
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x.
@ -572,9 +595,15 @@ class Conv2dSubsampling(nn.Module):
# On entry, x is (N, T, idim)
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
x = self.conv(x)
# Now x is of shape (N, odim, ((T-3)//2-1)//2, ((idim-3)//2-1)//2)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
if torch.jit.is_tracing() and self.is_pnnx:
x = x.permute(0, 2, 1, 3).reshape(1, -1, self.conv_out_dim)
x = self.out(x)
else:
# Now x is of shape (N, odim, ((T-3)//2-1)//2, ((idim-3)//2-1)//2)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
# Now x is of shape (N, ((T-3)//2-1))//2, odim)
x = self.out_norm(x)
x = self.out_balancer(x)

View File

@ -169,6 +169,18 @@ def get_parser():
""",
)
parser.add_argument(
"--pnnx",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.trace for later
converting to PNNX. It will generate 3 files:
- encoder_jit_trace-pnnx.pt
- decoder_jit_trace-pnnx.pt
- joiner_jit_trace-pnnx.pt
""",
)
parser.add_argument(
"--context-size",
type=int,
@ -277,6 +289,10 @@ def main():
logging.info(params)
if params.pnnx:
params.is_pnnx = params.pnnx
logging.info("For PNNX")
logging.info("About to create model")
model = get_transducer_model(params, enable_giga=False)
@ -371,7 +387,18 @@ def main():
model.to("cpu")
model.eval()
if params.jit_trace is True:
if params.pnnx:
convert_scaled_to_non_scaled(model, inplace=True)
logging.info("Using torch.jit.trace()")
encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt"
export_encoder_model_jit_trace(model.encoder, encoder_filename)
decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt"
export_decoder_model_jit_trace(model.decoder, decoder_filename)
joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt"
export_joiner_model_jit_trace(model.joiner, joiner_filename)
elif params.jit_trace is True:
convert_scaled_to_non_scaled(model, inplace=True)
logging.info("Using torch.jit.trace()")
encoder_filename = params.exp_dir / "encoder_jit_trace.pt"

View File

@ -0,0 +1,295 @@
#!/usr/bin/env python3
# flake8: noqa
#
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
./lstm_transducer_stateless2/ncnn-decode.py \
--bpe-model-filename ./data/lang_bpe_500/bpe.model \
--encoder-param-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.param \
--encoder-bin-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.bin \
--decoder-param-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.param \
--decoder-bin-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.bin \
--joiner-param-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-iter-468000-avg-16-pnnx.ncnn.param \
--joiner-bin-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-iter-468000-avg-16-pnnx.ncnn.bin \
./test_wavs/1089-134686-0001.wav
"""
import argparse
import logging
from typing import List
import kaldifeat
import ncnn
import sentencepiece as spm
import torch
import torchaudio
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--bpe-model-filename",
type=str,
help="Path to bpe.model",
)
parser.add_argument(
"--encoder-param-filename",
type=str,
help="Path to encoder.ncnn.param",
)
parser.add_argument(
"--encoder-bin-filename",
type=str,
help="Path to encoder.ncnn.bin",
)
parser.add_argument(
"--decoder-param-filename",
type=str,
help="Path to decoder.ncnn.param",
)
parser.add_argument(
"--decoder-bin-filename",
type=str,
help="Path to decoder.ncnn.bin",
)
parser.add_argument(
"--joiner-param-filename",
type=str,
help="Path to joiner.ncnn.param",
)
parser.add_argument(
"--joiner-bin-filename",
type=str,
help="Path to joiner.ncnn.bin",
)
parser.add_argument(
"sound_filename",
type=str,
help="Path to foo.wav",
)
return parser.parse_args()
class Model:
def __init__(self, args):
self.init_encoder(args)
self.init_decoder(args)
self.init_joiner(args)
def init_encoder(self, args):
encoder_net = ncnn.Net()
encoder_net.opt.use_packing_layout = False
encoder_net.opt.use_fp16_storage = False
encoder_param = args.encoder_param_filename
encoder_model = args.encoder_bin_filename
encoder_net.load_param(encoder_param)
encoder_net.load_model(encoder_model)
self.encoder_net = encoder_net
def init_decoder(self, args):
decoder_param = args.decoder_param_filename
decoder_model = args.decoder_bin_filename
decoder_net = ncnn.Net()
decoder_net.opt.use_packing_layout = False
decoder_net.load_param(decoder_param)
decoder_net.load_model(decoder_model)
self.decoder_net = decoder_net
def init_joiner(self, args):
joiner_param = args.joiner_param_filename
joiner_model = args.joiner_bin_filename
joiner_net = ncnn.Net()
joiner_net.opt.use_packing_layout = False
joiner_net.load_param(joiner_param)
joiner_net.load_model(joiner_model)
self.joiner_net = joiner_net
def run_encoder(self, x, states):
with self.encoder_net.create_extractor() as ex:
ex.set_num_threads(10)
ex.input("in0", ncnn.Mat(x.numpy()).clone())
x_lens = torch.tensor([x.size(0)], dtype=torch.float32)
ex.input("in1", ncnn.Mat(x_lens.numpy()).clone())
ex.input("in2", ncnn.Mat(states[0].numpy()).clone())
ex.input("in3", ncnn.Mat(states[1].numpy()).clone())
ret, ncnn_out0 = ex.extract("out0")
assert ret == 0, ret
ret, ncnn_out1 = ex.extract("out1")
assert ret == 0, ret
ret, ncnn_out2 = ex.extract("out2")
assert ret == 0, ret
ret, ncnn_out3 = ex.extract("out3")
assert ret == 0, ret
encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone()
encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to(
torch.int32
)
hx = torch.from_numpy(ncnn_out2.numpy()).clone()
cx = torch.from_numpy(ncnn_out3.numpy()).clone()
return encoder_out, encoder_out_lens, hx, cx
def run_decoder(self, decoder_input):
assert decoder_input.dtype == torch.int32
with self.decoder_net.create_extractor() as ex:
ex.set_num_threads(10)
ex.input("in0", ncnn.Mat(decoder_input.numpy()).clone())
ret, ncnn_out0 = ex.extract("out0")
assert ret == 0, ret
decoder_out = torch.from_numpy(ncnn_out0.numpy()).clone()
return decoder_out
def run_joiner(self, encoder_out, decoder_out):
with self.joiner_net.create_extractor() as ex:
ex.set_num_threads(10)
ex.input("in0", ncnn.Mat(encoder_out.numpy()).clone())
ex.input("in1", ncnn.Mat(decoder_out.numpy()).clone())
ret, ncnn_out0 = ex.extract("out0")
assert ret == 0, ret
joiner_out = torch.from_numpy(ncnn_out0.numpy()).clone()
return joiner_out
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
def greedy_search(model: Model, encoder_out: torch.Tensor):
assert encoder_out.ndim == 2
T = encoder_out.size(0)
context_size = 2
blank_id = 0 # hard-code to 0
hyp = [blank_id] * context_size
decoder_input = torch.tensor(hyp, dtype=torch.int32) # (1, context_size)
decoder_out = model.run_decoder(decoder_input).squeeze(0)
# print(decoder_out.shape) # (512,)
for t in range(T):
encoder_out_t = encoder_out[t]
joiner_out = model.run_joiner(encoder_out_t, decoder_out)
# print(joiner_out.shape) # [500]
y = joiner_out.argmax(dim=0).tolist()
if y != blank_id:
hyp.append(y)
decoder_input = hyp[-context_size:]
decoder_input = torch.tensor(decoder_input, dtype=torch.int32)
decoder_out = model.run_decoder(decoder_input).squeeze(0)
return hyp[context_size:]
def main():
args = get_args()
logging.info(vars(args))
model = Model(args)
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model_filename)
sound_file = args.sound_filename
sample_rate = 16000
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = "cpu"
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = sample_rate
opts.mel_opts.num_bins = 80
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {sound_file}")
wave_samples = read_sound_files(
filenames=[sound_file],
expected_sample_rate=sample_rate,
)[0]
logging.info("Decoding started")
features = fbank(wave_samples)
num_encoder_layers = 12
d_model = 512
rnn_hidden_size = 1024
states = (
torch.zeros(num_encoder_layers, d_model),
torch.zeros(
num_encoder_layers,
rnn_hidden_size,
),
)
encoder_out, encoder_out_lens, hx, cx = model.run_encoder(features, states)
hyp = greedy_search(model, encoder_out)
logging.info(sound_file)
logging.info(sp.decode(hyp))
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,353 @@
#!/usr/bin/env python3
# flake8: noqa
#
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from typing import List, Optional
import ncnn
import sentencepiece as spm
import torch
import torchaudio
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--bpe-model-filename",
type=str,
help="Path to bpe.model",
)
parser.add_argument(
"--encoder-param-filename",
type=str,
help="Path to encoder.ncnn.param",
)
parser.add_argument(
"--encoder-bin-filename",
type=str,
help="Path to encoder.ncnn.bin",
)
parser.add_argument(
"--decoder-param-filename",
type=str,
help="Path to decoder.ncnn.param",
)
parser.add_argument(
"--decoder-bin-filename",
type=str,
help="Path to decoder.ncnn.bin",
)
parser.add_argument(
"--joiner-param-filename",
type=str,
help="Path to joiner.ncnn.param",
)
parser.add_argument(
"--joiner-bin-filename",
type=str,
help="Path to joiner.ncnn.bin",
)
parser.add_argument(
"sound_filename",
type=str,
help="Path to foo.wav",
)
return parser.parse_args()
class Model:
def __init__(self, args):
self.init_encoder(args)
self.init_decoder(args)
self.init_joiner(args)
def init_encoder(self, args):
encoder_net = ncnn.Net()
encoder_net.opt.use_packing_layout = False
encoder_net.opt.use_fp16_storage = False
encoder_param = args.encoder_param_filename
encoder_model = args.encoder_bin_filename
encoder_net.load_param(encoder_param)
encoder_net.load_model(encoder_model)
self.encoder_net = encoder_net
def init_decoder(self, args):
decoder_param = args.decoder_param_filename
decoder_model = args.decoder_bin_filename
decoder_net = ncnn.Net()
decoder_net.opt.use_packing_layout = False
decoder_net.load_param(decoder_param)
decoder_net.load_model(decoder_model)
self.decoder_net = decoder_net
def init_joiner(self, args):
joiner_param = args.joiner_param_filename
joiner_model = args.joiner_bin_filename
joiner_net = ncnn.Net()
joiner_net.opt.use_packing_layout = False
joiner_net.load_param(joiner_param)
joiner_net.load_model(joiner_model)
self.joiner_net = joiner_net
def run_encoder(self, x, states):
with self.encoder_net.create_extractor() as ex:
# ex.set_num_threads(10)
ex.input("in0", ncnn.Mat(x.numpy()).clone())
x_lens = torch.tensor([x.size(0)], dtype=torch.float32)
ex.input("in1", ncnn.Mat(x_lens.numpy()).clone())
ex.input("in2", ncnn.Mat(states[0].numpy()).clone())
ex.input("in3", ncnn.Mat(states[1].numpy()).clone())
ret, ncnn_out0 = ex.extract("out0")
assert ret == 0, ret
ret, ncnn_out1 = ex.extract("out1")
assert ret == 0, ret
ret, ncnn_out2 = ex.extract("out2")
assert ret == 0, ret
ret, ncnn_out3 = ex.extract("out3")
assert ret == 0, ret
encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone()
encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to(
torch.int32
)
hx = torch.from_numpy(ncnn_out2.numpy()).clone()
cx = torch.from_numpy(ncnn_out3.numpy()).clone()
return encoder_out, encoder_out_lens, hx, cx
def run_decoder(self, decoder_input):
assert decoder_input.dtype == torch.int32
with self.decoder_net.create_extractor() as ex:
# ex.set_num_threads(10)
ex.input("in0", ncnn.Mat(decoder_input.numpy()).clone())
ret, ncnn_out0 = ex.extract("out0")
assert ret == 0, ret
decoder_out = torch.from_numpy(ncnn_out0.numpy()).clone()
return decoder_out
def run_joiner(self, encoder_out, decoder_out):
with self.joiner_net.create_extractor() as ex:
# ex.set_num_threads(10)
ex.input("in0", ncnn.Mat(encoder_out.numpy()).clone())
ex.input("in1", ncnn.Mat(decoder_out.numpy()).clone())
ret, ncnn_out0 = ex.extract("out0")
assert ret == 0, ret
joiner_out = torch.from_numpy(ncnn_out0.numpy()).clone()
return joiner_out
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
def create_streaming_feature_extractor() -> OnlineFeature:
"""Create a CPU streaming feature extractor.
At present, we assume it returns a fbank feature extractor with
fixed options. In the future, we will support passing in the options
from outside.
Returns:
Return a CPU streaming feature extractor.
"""
opts = FbankOptions()
opts.device = "cpu"
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
return OnlineFbank(opts)
def greedy_search(
model: Model,
encoder_out: torch.Tensor,
decoder_out: Optional[torch.Tensor] = None,
hyp: Optional[List[int]] = None,
):
assert encoder_out.ndim == 1
context_size = 2
blank_id = 0
if decoder_out is None:
assert hyp is None, hyp
hyp = [blank_id] * context_size
decoder_input = torch.tensor(
hyp, dtype=torch.int32
) # (1, context_size)
decoder_out = model.run_decoder(decoder_input).squeeze(0)
else:
assert decoder_out.ndim == 1
assert hyp is not None, hyp
joiner_out = model.run_joiner(encoder_out, decoder_out)
y = joiner_out.argmax(dim=0).tolist()
if y != blank_id:
hyp.append(y)
decoder_input = hyp[-context_size:]
decoder_input = torch.tensor(decoder_input, dtype=torch.int32)
decoder_out = model.run_decoder(decoder_input).squeeze(0)
return hyp, decoder_out
def main():
args = get_args()
logging.info(vars(args))
model = Model(args)
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model_filename)
sound_file = args.sound_filename
sample_rate = 16000
logging.info("Constructing Fbank computer")
online_fbank = create_streaming_feature_extractor()
logging.info(f"Reading sound files: {sound_file}")
wave_samples = read_sound_files(
filenames=[sound_file],
expected_sample_rate=sample_rate,
)[0]
logging.info(wave_samples.shape)
num_encoder_layers = 12
batch_size = 1
d_model = 512
rnn_hidden_size = 1024
states = (
torch.zeros(num_encoder_layers, batch_size, d_model),
torch.zeros(
num_encoder_layers,
batch_size,
rnn_hidden_size,
),
)
hyp = None
decoder_out = None
num_processed_frames = 0
segment = 9
offset = 4
chunk = 3200 # 0.2 second
start = 0
while start < wave_samples.numel():
end = min(start + chunk, wave_samples.numel())
samples = wave_samples[start:end]
start += chunk
online_fbank.accept_waveform(
sampling_rate=sample_rate,
waveform=samples,
)
while online_fbank.num_frames_ready - num_processed_frames >= segment:
frames = []
for i in range(segment):
frames.append(online_fbank.get_frame(num_processed_frames + i))
num_processed_frames += offset
frames = torch.cat(frames, dim=0)
encoder_out, encoder_out_lens, hx, cx = model.run_encoder(
frames, states
)
states = (hx, cx)
hyp, decoder_out = greedy_search(
model, encoder_out.squeeze(0), decoder_out, hyp
)
online_fbank.accept_waveform(
sampling_rate=sample_rate, waveform=torch.zeros(8000, dtype=torch.int32)
)
online_fbank.input_finished()
while online_fbank.num_frames_ready - num_processed_frames >= segment:
frames = []
for i in range(segment):
frames.append(online_fbank.get_frame(num_processed_frames + i))
num_processed_frames += offset
frames = torch.cat(frames, dim=0)
encoder_out, encoder_out_lens, hx, cx = model.run_encoder(
frames, states
)
states = (hx, cx)
hyp, decoder_out = greedy_search(
model, encoder_out.squeeze(0), decoder_out, hyp
)
context_size = 2
logging.info(sound_file)
logging.info(sp.decode(hyp[context_size:]))
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -406,6 +406,8 @@ def get_params() -> AttributeDict:
"decoder_dim": 512,
# parameters for joiner
"joiner_dim": 512,
# True to generate a model that can be exported via PNNX
"is_pnnx": False,
# parameters for Noam
"model_warm_step": 3000, # arg given to model, not for lrate
"env_info": get_env_info(),
@ -424,6 +426,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
aux_layer_period=params.aux_layer_period,
is_pnnx=params.is_pnnx,
)
return encoder

View File

@ -30,6 +30,7 @@ from typing import List
import torch
import torch.nn as nn
from scaling import (
BasicNorm,
ScaledConv1d,
ScaledConv2d,
ScaledEmbedding,
@ -38,6 +39,29 @@ from scaling import (
)
class NonScaledNorm(nn.Module):
"""See BasicNorm for doc"""
def __init__(
self,
num_channels: int,
eps_exp: float,
channel_dim: int = -1, # CAUTION: see documentation.
):
super().__init__()
self.num_channels = num_channels
self.channel_dim = channel_dim
self.eps_exp = eps_exp
def forward(self, x: torch.Tensor) -> torch.Tensor:
if not torch.jit.is_tracing():
assert x.shape[self.channel_dim] == self.num_channels
scales = (
torch.mean(x * x, dim=self.channel_dim, keepdim=True) + self.eps_exp
).pow(-0.5)
return x * scales
def scaled_linear_to_linear(scaled_linear: ScaledLinear) -> nn.Linear:
"""Convert an instance of ScaledLinear to nn.Linear.
@ -174,6 +198,16 @@ def scaled_embedding_to_embedding(
return embedding
def convert_basic_norm(basic_norm: BasicNorm) -> NonScaledNorm:
assert isinstance(basic_norm, BasicNorm), type(BasicNorm)
norm = NonScaledNorm(
num_channels=basic_norm.num_channels,
eps_exp=basic_norm.eps.data.exp().item(),
channel_dim=basic_norm.channel_dim,
)
return norm
def scaled_lstm_to_lstm(scaled_lstm: ScaledLSTM) -> nn.LSTM:
"""Convert an instance of ScaledLSTM to nn.LSTM.
@ -256,6 +290,8 @@ def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False):
d[name] = scaled_conv2d_to_conv2d(m)
elif isinstance(m, ScaledEmbedding):
d[name] = scaled_embedding_to_embedding(m)
elif isinstance(m, BasicNorm):
d[name] = convert_basic_norm(m)
elif isinstance(m, ScaledLSTM):
d[name] = scaled_lstm_to_lstm(m)