diff --git a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml new file mode 100755 index 000000000..19d606682 --- /dev/null +++ b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml @@ -0,0 +1,160 @@ +#!/usr/bin/env bash + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03 + +log "Downloading pre-trained model from $repo_url" +git lfs install +git clone $repo_url +repo=$(basename $repo_url) + +log "Display test files" +tree $repo/ +soxi $repo/test_wavs/*.wav +ls -lh $repo/test_wavs/*.wav + +pushd $repo/exp +ln -s pretrained-iter-468000-avg-16.pt pretrained.pt +ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt +popd + +log "Install ncnn and pnnx" + +# We are using a modified ncnn here. Will try to merge it to the official repo +# of ncnn +git clone https://github.com/csukuangfj/ncnn +pushd ncnn +git submodule init +git submodule update python/pybind11 +python3 setup.py bdist_wheel +ls -lh dist/ +pip install dist/*.whl +cd tools/pnnx +mkdir build +cd build +cmake .. +make -j4 pnnx + +./src/pnnx || echo "pass" + +popd + +log "Test exporting to pnnx format" + +./lstm_transducer_stateless2/export.py \ + --exp-dir $repo/exp \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model 0 \ + --pnnx 1 + +./ncnn/tools/pnnx/build/src/pnnx $repo/exp/encoder_jit_trace-pnnx.pt +./ncnn/tools/pnnx/build/src/pnnx $repo/exp/decoder_jit_trace-pnnx.pt +./ncnn/tools/pnnx/build/src/pnnx $repo/exp/joiner_jit_trace-pnnx.pt + +./lstm_transducer_stateless2/ncnn-decode.py \ + --bpe-model-filename $repo/data/lang_bpe_500/bpe.model \ + --encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \ + --encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \ + --decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \ + --decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \ + --joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \ + --joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \ + $repo/test_wavs/1089-134686-0001.wav + +./lstm_transducer_stateless2/streaming-ncnn-decode.py \ + --bpe-model-filename $repo/data/lang_bpe_500/bpe.model \ + --encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \ + --encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \ + --decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \ + --decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \ + --joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \ + --joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \ + $repo/test_wavs/1089-134686-0001.wav + + + +log "Test exporting with torch.jit.trace()" + +./lstm_transducer_stateless2/export.py \ + --exp-dir $repo/exp \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model 0 \ + --jit-trace 1 + +log "Decode with models exported by torch.jit.trace()" + +./lstm_transducer_stateless2/jit_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --encoder-model-filename $repo/exp/encoder_jit_trace.pt \ + --decoder-model-filename $repo/exp/decoder_jit_trace.pt \ + --joiner-model-filename 
$repo/exp/joiner_jit_trace.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./lstm_transducer_stateless2/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./lstm_transducer_stateless2/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" +echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" +if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"ncnn" ]]; then + mkdir -p lstm_transducer_stateless2/exp + ln -s $PWD/$repo/exp/pretrained.pt lstm_transducer_stateless2/exp/epoch-999.pt + ln -s $PWD/$repo/data/lang_bpe_500 data/ + + ls -lh data + ls -lh lstm_transducer_stateless2/exp + + log "Decoding test-clean and test-other" + + # use a small value for decoding with CPU + max_duration=100 + + for method in greedy_search fast_beam_search modified_beam_search; do + log "Decoding with $method" + + ./lstm_transducer_stateless2/decode.py \ + --decoding-method $method \ + --epoch 999 \ + --avg 1 \ + --use-averaged-model 0 \ + --max-duration $max_duration \ + --exp-dir lstm_transducer_stateless2/exp + done + + rm lstm_transducer_stateless2/exp/*.pt +fi diff --git a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml new file mode 100644 index 000000000..a3aa0c0b8 --- /dev/null +++ b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml @@ -0,0 +1,136 @@ +name: run-librispeech-lstm-transducer-2022-09-03 + +on: + push: + branches: + - master + pull_request: + types: [labeled] + + schedule: + # minute (0-59) + # hour (0-23) + # day of the month (1-31) + # month (1-12) + # day of the week (0-6) + # nightly build at 15:50 UTC time every day + - cron: "50 15 * * *" + +jobs: + run_librispeech_pruned_transducer_stateless3_2022_05_13: + if: github.event.label.name == 'ncnn' || github.event_name == 'push' || github.event_name == 'schedule' + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-18.04] + python-version: [3.8] + + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: '**/requirements-ci.txt' + + - name: Install Python dependencies + run: | + grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install + pip uninstall -y protobuf + pip install --no-binary protobuf protobuf + + - name: Cache kaldifeat + id: my-cache + uses: actions/cache@v2 + with: + path: | + ~/tmp/kaldifeat + key: cache-tmp-${{ matrix.python-version }} + + - name: Install kaldifeat + if: steps.my-cache.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/install-kaldifeat.sh + + - name: Cache 
LibriSpeech test-clean and test-other datasets + id: libri-test-clean-and-test-other-data + uses: actions/cache@v2 + with: + path: | + ~/tmp/download + key: cache-libri-test-clean-and-test-other + + - name: Download LibriSpeech test-clean and test-other + if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh + + - name: Prepare manifests for LibriSpeech test-clean and test-other + shell: bash + run: | + .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh + + - name: Cache LibriSpeech test-clean and test-other fbank features + id: libri-test-clean-and-test-other-fbank + uses: actions/cache@v2 + with: + path: | + ~/tmp/fbank-libri + key: cache-libri-fbank-test-clean-and-test-other-v2 + + - name: Compute fbank for LibriSpeech test-clean and test-other + if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh + + - name: Inference with pre-trained model + shell: bash + env: + GITHUB_EVENT_NAME: ${{ github.event_name }} + GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} + run: | + mkdir -p egs/librispeech/ASR/data + ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank + ls -lh egs/librispeech/ASR/data/* + + sudo apt-get -qq install git-lfs tree sox + export PYTHONPATH=$PWD:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH + + .github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml + + - name: Display decoding results for lstm_transducer_stateless2 + if: github.event_name == 'schedule' || github.event.label.name == 'ncnn' + shell: bash + run: | + cd egs/librispeech/ASR + tree lstm_transducer_stateless2/exp + cd lstm_transducer_stateless2/exp + echo "===greedy search===" + find greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===fast_beam_search===" + find fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===modified beam search===" + find modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + - name: Upload decoding results for lstm_transducer_stateless2 + uses: actions/upload-artifact@v2 + if: github.event_name == 'schedule' || github.event.label.name == 'ncnn' + with: + name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-lstm_transducer_stateless2-2022-09-03 + path: egs/librispeech/ASR/lstm_transducer_stateless2/exp/ diff --git a/.gitignore b/.gitignore index 1dbf8f395..406deff6a 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,5 @@ log *.bak *-bak *bak.py +*.param +*.bin diff --git a/docs/source/recipes/librispeech/images/librispeech-lstm-transducer-tensorboard-log.png b/docs/source/recipes/librispeech/images/librispeech-lstm-transducer-tensorboard-log.png new file mode 100644 index 000000000..cc475a45f Binary files /dev/null and b/docs/source/recipes/librispeech/images/librispeech-lstm-transducer-tensorboard-log.png differ diff --git 
a/docs/source/recipes/librispeech/index.rst b/docs/source/recipes/librispeech/index.rst index 5fa08ab6b..6c91b6750 100644 --- a/docs/source/recipes/librispeech/index.rst +++ b/docs/source/recipes/librispeech/index.rst @@ -6,3 +6,4 @@ LibriSpeech tdnn_lstm_ctc conformer_ctc + lstm_pruned_stateless_transducer diff --git a/docs/source/recipes/librispeech/lstm_pruned_stateless_transducer.rst b/docs/source/recipes/librispeech/lstm_pruned_stateless_transducer.rst new file mode 100644 index 000000000..0aeccb70a --- /dev/null +++ b/docs/source/recipes/librispeech/lstm_pruned_stateless_transducer.rst @@ -0,0 +1,625 @@ +Transducer +========== + +.. hint:: + + Please scroll down to the bottom of this page to find download links + for pretrained models if you don't want to train a model from scratch. + + +This tutorial shows you how to train a transducer model +with the `LibriSpeech `_ dataset. + +We use pruned RNN-T to compute the loss. + +.. note:: + + You can find the paper about pruned RNN-T at the following address: + + ``_ + +The transducer model consists of 3 parts: + + - Encoder, a.k.a, transcriber. We use an LSTM model + - Decoder, a.k.a, predictor. We use a model consisting of ``nn.Embedding`` + and ``nn.Conv1d`` + - Joiner, a.k.a, the joint network. + +.. caution:: + + Contrary to the conventional RNN-T models, we use a stateless decoder. + That is, it has no recurrent connections. + +.. hint:: + + Since the encoder model is an LSTM, not Transformer/Conformer, the + resulting model is suitable for streaming/online ASR. + + +Which model to use +------------------ + +Currently, there are two folders about LSTM stateless transducer training: + + - ``(1)`` ``_ + + This recipe uses only LibriSpeech during training. + + - ``(2)`` ``_ + + This recipe uses GigaSpeech + LibriSpeech during training. + +``(1)`` and ``(2)`` use the same model architecture. The only difference is that ``(2)`` supports +multi-dataset. Since ``(2)`` uses more data, it has a lower WER than ``(1)`` but it needs +more training time. + +We use ``lstm_transducer_stateless2`` as an example below. + +.. note:: + + You need to download the `GigaSpeech `_ dataset + to run ``(2)``. If you have only ``LibriSpeech`` dataset available, feel free to use ``(1)``. + +Data preparation +---------------- + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./prepare.sh + + # If you use (1), you can **skip** the following command + $ ./prepare_giga_speech.sh + +The script ``./prepare.sh`` handles the data preparation for you, **automagically**. +All you need to do is to run it. + +The data preparation contains several stages, you can use the following two +options: + + - ``--stage`` + - ``--stop-stage`` + +to control which stage(s) should be run. By default, all stages are executed. + + +For example, + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./prepare.sh --stage 0 --stop-stage 0 + +means to run only stage 0. + +To run stage 2 to stage 5, use: + +.. code-block:: bash + + $ ./prepare.sh --stage 2 --stop-stage 5 + +.. hint:: + + If you have pre-downloaded the `LibriSpeech `_ + dataset and the `musan `_ dataset, say, + they are saved in ``/tmp/LibriSpeech`` and ``/tmp/musan``, you can modify + the ``dl_dir`` variable in ``./prepare.sh`` to point to ``/tmp`` so that + ``./prepare.sh`` won't re-download them. + +.. note:: + + All generated files by ``./prepare.sh``, e.g., features, lexicon, etc, + are saved in ``./data`` directory. + +We provide the following YouTube video showing how to run ``./prepare.sh``. + +.. 
note:: + + To get the latest news of `next-gen Kaldi `_, please subscribe + the following YouTube channel by `Nadira Povey `_: + + ``_ + +.. youtube:: ofEIoJL-mGM + +Training +-------- + +Configurable options +~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./lstm_transducer_stateless2/train.py --help + +shows you the training options that can be passed from the commandline. +The following options are used quite often: + + - ``--full-libri`` + + If it's True, the training part uses all the training data, i.e., + 960 hours. Otherwise, the training part uses only the subset + ``train-clean-100``, which has 100 hours of training data. + + .. CAUTION:: + + The training set is perturbed by speed with two factors: 0.9 and 1.1. + If ``--full-libri`` is True, each epoch actually processes + ``3x960 == 2880`` hours of data. + + - ``--num-epochs`` + + It is the number of epochs to train. For instance, + ``./lstm_transducer_stateless2/train.py --num-epochs 30`` trains for 30 epochs + and generates ``epoch-1.pt``, ``epoch-2.pt``, ..., ``epoch-30.pt`` + in the folder ``./lstm_transducer_stateless2/exp``. + + - ``--start-epoch`` + + It's used to resume training. + ``./lstm_transducer_stateless2/train.py --start-epoch 10`` loads the + checkpoint ``./lstm_transducer_stateless2/exp/epoch-9.pt`` and starts + training from epoch 10, based on the state from epoch 9. + + - ``--world-size`` + + It is used for multi-GPU single-machine DDP training. + + - (a) If it is 1, then no DDP training is used. + + - (b) If it is 2, then GPU 0 and GPU 1 are used for DDP training. + + The following shows some use cases with it. + + **Use case 1**: You have 4 GPUs, but you only want to use GPU 0 and + GPU 2 for training. You can do the following: + + .. code-block:: bash + + $ cd egs/librispeech/ASR + $ export CUDA_VISIBLE_DEVICES="0,2" + $ ./lstm_transducer_stateless2/train.py --world-size 2 + + **Use case 2**: You have 4 GPUs and you want to use all of them + for training. You can do the following: + + .. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./lstm_transducer_stateless2/train.py --world-size 4 + + **Use case 3**: You have 4 GPUs but you only want to use GPU 3 + for training. You can do the following: + + .. code-block:: bash + + $ cd egs/librispeech/ASR + $ export CUDA_VISIBLE_DEVICES="3" + $ ./lstm_transducer_stateless2/train.py --world-size 1 + + .. caution:: + + Only multi-GPU single-machine DDP training is implemented at present. + Multi-GPU multi-machine DDP training will be added later. + + - ``--max-duration`` + + It specifies the number of seconds over all utterances in a + batch, before **padding**. + If you encounter CUDA OOM, please reduce it. + + .. HINT:: + + Due to padding, the number of seconds of all utterances in a + batch will usually be larger than ``--max-duration``. + + A larger value for ``--max-duration`` may cause OOM during training, + while a smaller value may increase the training time. You have to + tune it. + + - ``--giga-prob`` + + The probability to select a batch from the ``GigaSpeech`` dataset. + Note: It is available only for ``(2)``. + +Pre-configured options +~~~~~~~~~~~~~~~~~~~~~~ + +There are some training options, e.g., weight decay, +number of warmup steps, results dir, etc, +that are not passed from the commandline. +They are pre-configured by the function ``get_params()`` in +`lstm_transducer_stateless2/train.py `_ + +You don't need to change these pre-configured parameters. 
If you really need to change +them, please modify ``./lstm_transducer_stateless2/train.py`` directly. + +Training logs +~~~~~~~~~~~~~ + +Training logs and checkpoints are saved in ``lstm_transducer_stateless2/exp``. +You will find the following files in that directory: + + - ``epoch-1.pt``, ``epoch-2.pt``, ... + + These are checkpoint files saved at the end of each epoch, containing model + ``state_dict`` and optimizer ``state_dict``. + To resume training from some checkpoint, say ``epoch-10.pt``, you can use: + + .. code-block:: bash + + $ ./lstm_transducer_stateless2/train.py --start-epoch 11 + + - ``checkpoint-436000.pt``, ``checkpoint-438000.pt``, ... + + These are checkpoint files saved every ``--save-every-n`` batches, + containing model ``state_dict`` and optimizer ``state_dict``. + To resume training from some checkpoint, say ``checkpoint-436000``, you can use: + + .. code-block:: bash + + $ ./lstm_transducer_stateless2/train.py --start-batch 436000 + + - ``tensorboard/`` + + This folder contains TensorBoard logs. Training loss, validation loss, learning + rate, etc, are recorded in these logs. You can visualize them by: + + .. code-block:: bash + + $ cd lstm_transducer_stateless2/exp/tensorboard + $ tensorboard dev upload --logdir . --description "LSTM transducer training for LibriSpeech with icefall" + + It will print something like below: + + .. code-block:: + + TensorFlow installation not found - running with reduced feature set. + Upload started and will continue reading any new data as it's added to the logdir. + + To stop uploading, press Ctrl-C. + + New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/cj2vtPiwQHKN9Q1tx6PTpg/ + + [2022-09-20T15:50:50] Started scanning logdir. + Uploading 4468 scalars... + [2022-09-20T15:53:02] Total uploaded: 210171 scalars, 0 tensors, 0 binary objects + Listening for new data in logdir... + + Note there is a URL in the above output, click it and you will see + the following screenshot: + + .. figure:: images/librispeech-lstm-transducer-tensorboard-log.png + :width: 600 + :alt: TensorBoard screenshot + :align: center + :target: https://tensorboard.dev/experiment/lzGnETjwRxC3yghNMd4kPw/ + + TensorBoard screenshot. + + .. hint:: + + If you don't have access to google, you can use the following command + to view the tensorboard log locally: + + .. code-block:: bash + + cd lstm_transducer_stateless2/exp/tensorboard + tensorboard --logdir . --port 6008 + + It will print the following message: + + .. code-block:: + + Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all + TensorBoard 2.8.0 at http://localhost:6008/ (Press CTRL+C to quit) + + Now start your browser and go to ``_ to view the tensorboard + logs. + + + - ``log/log-train-xxxx`` + + It is the detailed training log in text format, same as the one + you saw printed to the console during training. + +Usage example +~~~~~~~~~~~~~ + +You can use the following command to start the training using 8 GPUs: + +.. code-block:: bash + + export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" + ./lstm_transducer_stateless2/train.py \ + --world-size 8 \ + --num-epochs 35 \ + --start-epoch 1 \ + --full-libri 1 \ + --exp-dir lstm_transducer_stateless2/exp \ + --max-duration 500 \ + --use-fp16 0 \ + --lr-epochs 10 \ + --num-workers 2 \ + --giga-prob 0.9 + +Decoding +-------- + +The decoding part uses checkpoints saved by the training part, so you have +to run the training part first. + +.. 
hint:: + + There are two kinds of checkpoints: + + - (1) ``epoch-1.pt``, ``epoch-2.pt``, ..., which are saved at the end + of each epoch. You can pass ``--epoch`` to + ``lstm_transducer_stateless2/decode.py`` to use them. + + - (2) ``checkpoint-436000.pt``, ``checkpoint-438000.pt``, ..., which are saved + every ``--save-every-n`` batches. You can pass ``--iter`` to + ``lstm_transducer_stateless2/decode.py`` to use them. + + We suggest that you try both types of checkpoints and choose the one + that produces the lowest WERs. + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./lstm_transducer_stateless2/decode.py --help + +shows the options for decoding. + +The following shows two examples: + +.. code-block:: bash + + for m in greedy_search fast_beam_search modified_beam_search; do + for epoch in 17; do + for avg in 1 2; do + ./lstm_transducer_stateless2/decode.py \ + --epoch $epoch \ + --avg $avg \ + --exp-dir lstm_transducer_stateless2/exp \ + --max-duration 600 \ + --num-encoder-layers 12 \ + --rnn-hidden-size 1024 \ + --decoding-method $m \ + --use-averaged-model True \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 \ + --beam-size 4 + done + done + done + + +.. code-block:: bash + + for m in greedy_search fast_beam_search modified_beam_search; do + for iter in 474000; do + for avg in 8 10 12 14 16 18; do + ./lstm_transducer_stateless2/decode.py \ + --iter $iter \ + --avg $avg \ + --exp-dir lstm_transducer_stateless2/exp \ + --max-duration 600 \ + --num-encoder-layers 12 \ + --rnn-hidden-size 1024 \ + --decoding-method $m \ + --use-averaged-model True \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 \ + --beam-size 4 + done + done + done + +Export models +------------- + +`lstm_transducer_stateless2/export.py `_ supports exporting checkpoints from ``lstm_transducer_stateless2/exp`` in the following ways. + +Export ``model.state_dict()`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Checkpoints saved by ``lstm_transducer_stateless2/train.py`` also include +``optimizer.state_dict()``. It is useful for resuming training. But after training, +we are interested only in ``model.state_dict()``. You can use the following +command to extract ``model.state_dict()``. + +.. code-block:: bash + + # Assume that --iter 468000 --avg 16 produces the smallest WER + # (You can get such information after running ./lstm_transducer_stateless2/decode.py) + + iter=468000 + avg=16 + + ./lstm_transducer_stateless2/export.py \ + --exp-dir ./lstm_transducer_stateless2/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --iter $iter \ + --avg $avg + +It will generate a file ``./lstm_transducer_stateless2/exp/pretrained.pt``. + +.. hint:: + + To use the generated ``pretrained.pt`` for ``lstm_transducer_stateless2/decode.py``, + you can run: + + .. code-block:: bash + + cd lstm_transducer_stateless2/exp + ln -s pretrained.pt epoch-9999.pt + + And then pass ``--epoch 9999 --avg 1 --use-averaged-model 0`` to + ``./lstm_transducer_stateless2/decode.py``. + +To use the exported model with ``./lstm_transducer_stateless2/pretrained.py``, you +can run: + +.. code-block:: bash + + ./lstm_transducer_stateless2/pretrained.py \ + --checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +Export model using ``torch.jit.trace()`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. 
code-block:: bash + + iter=468000 + avg=16 + + ./lstm_transducer_stateless2/export.py \ + --exp-dir ./lstm_transducer_stateless2/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --iter $iter \ + --avg $avg \ + --jit-trace 1 + +It will generate 3 files: + + - ``./lstm_transducer_stateless2/exp/encoder_jit_trace.pt`` + - ``./lstm_transducer_stateless2/exp/decoder_jit_trace.pt`` + - ``./lstm_transducer_stateless2/exp/joiner_jit_trace.pt`` + +To use the generated files with ``./lstm_transducer_stateless2/jit_pretrained.py``: + +.. code-block:: bash + + ./lstm_transducer_stateless2/jit_pretrained.py \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --encoder-model-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace.pt \ + --decoder-model-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace.pt \ + --joiner-model-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace.pt \ + /path/to/foo.wav \ + /path/to/bar.wav + +Export model for ncnn +~~~~~~~~~~~~~~~~~~~~~ + +We support exporting pretrained LSTM transducer models to +`ncnn `_ using +`pnnx `_. + +First, let us install a modified version of ``ncnn``: + +.. code-block:: bash + + git clone https://github.com/csukuangfj/ncnn + cd ncnn + git submodule update --recursive --init + python3 setup.py bdist_wheel + ls -lh dist/ + pip install ./dist/*.whl + + # now build pnnx + cd tools/pnnx + mkdir build + cd build + cmake .. + make -j4 + export PATH=$PWD/src:$PATH + + ./src/pnnx + +.. note:: + + We assume that you have added the path to the binary ``pnnx`` to the + environment variable ``PATH``. + +Second, let us export the model using ``torch.jit.trace()``, in a format that is suitable +for ``pnnx``: + +.. code-block:: bash + + iter=468000 + avg=16 + + ./lstm_transducer_stateless2/export.py \ + --exp-dir ./lstm_transducer_stateless2/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --iter $iter \ + --avg $avg \ + --pnnx 1 + +It will generate 3 files: + + - ``./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.pt`` + - ``./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.pt`` + - ``./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.pt`` + +Third, convert the torchscript models to ``ncnn`` format: + +.. code-block:: bash + + pnnx ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.pt + pnnx ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.pt + pnnx ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.pt + +It will generate the following files: + + - ``./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.param`` + - ``./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.bin`` + - ``./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.param`` + - ``./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.bin`` + - ``./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.param`` + - ``./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.bin`` + +To use the above generated files, run: + +.. 
code-block:: bash + +./lstm_transducer_stateless2/ncnn-decode.py \ + --bpe-model-filename ./data/lang_bpe_500/bpe.model \ + --encoder-param-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.param \ + --encoder-bin-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.bin \ + --decoder-param-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.param \ + --decoder-bin-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.bin \ + --joiner-param-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.param \ + --joiner-bin-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.bin \ + /path/to/foo.wav + +.. code-block:: bash + +./lstm_transducer_stateless2/streaming-ncnn-decode.py \ + --bpe-model-filename ./data/lang_bpe_500/bpe.model \ + --encoder-param-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.param \ + --encoder-bin-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.bin \ + --decoder-param-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.param \ + --decoder-bin-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.bin \ + --joiner-param-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.param \ + --joiner-bin-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.bin \ + /path/to/foo.wav + +To use the above generated files in C++, please see +``_ + +It is able to generate a static linked library that can be run on Linux, Windows, +macOS, Raspberry Pi, etc. + +Download pretrained models +-------------------------- + +If you don't want to train from scratch, you can download the pretrained models +by visiting the following links: + + - ``_ + + - ``_ + + See ``_ + for the details of the above pretrained models + +You can find more usages of the pretrained models in +``_ diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py b/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py index 0d268ab07..c54a4c478 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py @@ -116,6 +116,8 @@ class RNN(EncoderInterface): Period of auxiliary layers used for random combiner during training. If set to 0, will not use the random combiner (Default). You can set a positive integer to use the random combiner, e.g., 3. + is_pnnx: + True to make this class exportable via PNNX. 
""" def __init__( @@ -129,6 +131,7 @@ class RNN(EncoderInterface): dropout: float = 0.1, layer_dropout: float = 0.075, aux_layer_period: int = 0, + is_pnnx: bool = False, ) -> None: super(RNN, self).__init__() @@ -142,7 +145,13 @@ class RNN(EncoderInterface): # That is, it does two things simultaneously: # (1) subsampling: T -> T//subsampling_factor # (2) embedding: num_features -> d_model - self.encoder_embed = Conv2dSubsampling(num_features, d_model) + self.encoder_embed = Conv2dSubsampling( + num_features, + d_model, + is_pnnx=is_pnnx, + ) + + self.is_pnnx = is_pnnx self.num_encoder_layers = num_encoder_layers self.d_model = d_model @@ -209,7 +218,13 @@ class RNN(EncoderInterface): # lengths = ((x_lens - 3) // 2 - 1) // 2 # issue an warning # # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0 - lengths = (((x_lens - 3) >> 1) - 1) >> 1 + if not self.is_pnnx: + lengths = (((x_lens - 3) >> 1) - 1) >> 1 + else: + lengths1 = torch.floor((x_lens - 3) / 2) + lengths = torch.floor((lengths1 - 1) / 2) + lengths = lengths.to(x_lens) + if not torch.jit.is_tracing(): assert x.size(0) == lengths.max().item() @@ -359,7 +374,7 @@ class RNNEncoderLayer(nn.Module): # for cell state assert states[1].shape == (1, src.size(1), self.rnn_hidden_size) src_lstm, new_states = self.lstm(src, states) - src = src + self.dropout(src_lstm) + src = self.dropout(src_lstm) + src # feed forward module src = src + self.dropout(self.feed_forward(src)) @@ -505,6 +520,7 @@ class Conv2dSubsampling(nn.Module): layer1_channels: int = 8, layer2_channels: int = 32, layer3_channels: int = 128, + is_pnnx: bool = False, ) -> None: """ Args: @@ -517,6 +533,9 @@ class Conv2dSubsampling(nn.Module): Number of channels in layer1 layer1_channels: Number of channels in layer2 + is_pnnx: + True if we are converting the model to PNNX format. + False otherwise. """ assert in_channels >= 9 super().__init__() @@ -559,6 +578,10 @@ class Conv2dSubsampling(nn.Module): channel_dim=-1, min_positive=0.45, max_positive=0.55 ) + # ncnn supports only batch size == 1 + self.is_pnnx = is_pnnx + self.conv_out_dim = self.out.weight.shape[1] + def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. @@ -572,9 +595,15 @@ class Conv2dSubsampling(nn.Module): # On entry, x is (N, T, idim) x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) x = self.conv(x) - # Now x is of shape (N, odim, ((T-3)//2-1)//2, ((idim-3)//2-1)//2) - b, c, t, f = x.size() - x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + + if torch.jit.is_tracing() and self.is_pnnx: + x = x.permute(0, 2, 1, 3).reshape(1, -1, self.conv_out_dim) + x = self.out(x) + else: + # Now x is of shape (N, odim, ((T-3)//2-1)//2, ((idim-3)//2-1)//2) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + # Now x is of shape (N, ((T-3)//2-1))//2, odim) x = self.out_norm(x) x = self.out_balancer(x) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py index a1ed6b3b1..8c4b243b0 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py @@ -169,6 +169,18 @@ def get_parser(): """, ) + parser.add_argument( + "--pnnx", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.trace for later + converting to PNNX. 
It will generate 3 files: + - encoder_jit_trace-pnnx.pt + - decoder_jit_trace-pnnx.pt + - joiner_jit_trace-pnnx.pt + """, + ) + parser.add_argument( "--context-size", type=int, @@ -277,6 +289,10 @@ def main(): logging.info(params) + if params.pnnx: + params.is_pnnx = params.pnnx + logging.info("For PNNX") + logging.info("About to create model") model = get_transducer_model(params, enable_giga=False) @@ -371,7 +387,18 @@ def main(): model.to("cpu") model.eval() - if params.jit_trace is True: + if params.pnnx: + convert_scaled_to_non_scaled(model, inplace=True) + logging.info("Using torch.jit.trace()") + encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt" + export_encoder_model_jit_trace(model.encoder, encoder_filename) + + decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt" + export_decoder_model_jit_trace(model.decoder, decoder_filename) + + joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt" + export_joiner_model_jit_trace(model.joiner, joiner_filename) + elif params.jit_trace is True: convert_scaled_to_non_scaled(model, inplace=True) logging.info("Using torch.jit.trace()") encoder_filename = params.exp_dir / "encoder_jit_trace.pt" diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py new file mode 100755 index 000000000..410de8d3d --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py @@ -0,0 +1,295 @@ +#!/usr/bin/env python3 +# flake8: noqa +# +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + ./lstm_transducer_stateless2/ncnn-decode.py \ + --bpe-model-filename ./data/lang_bpe_500/bpe.model \ + --encoder-param-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.param \ + --encoder-bin-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.bin \ + --decoder-param-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.param \ + --decoder-bin-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.bin \ + --joiner-param-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-iter-468000-avg-16-pnnx.ncnn.param \ + --joiner-bin-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-iter-468000-avg-16-pnnx.ncnn.bin \ + ./test_wavs/1089-134686-0001.wav +""" + +import argparse +import logging +from typing import List + +import kaldifeat +import ncnn +import sentencepiece as spm +import torch +import torchaudio + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--bpe-model-filename", + type=str, + help="Path to bpe.model", + ) + + parser.add_argument( + "--encoder-param-filename", + type=str, + help="Path to encoder.ncnn.param", + ) + + parser.add_argument( + "--encoder-bin-filename", + type=str, + help="Path to encoder.ncnn.bin", + ) + + parser.add_argument( + "--decoder-param-filename", + type=str, + help="Path to decoder.ncnn.param", + ) + + parser.add_argument( + "--decoder-bin-filename", + type=str, + help="Path to decoder.ncnn.bin", + ) + + parser.add_argument( + "--joiner-param-filename", + type=str, + help="Path to joiner.ncnn.param", + ) + + parser.add_argument( + "--joiner-bin-filename", + type=str, + help="Path to joiner.ncnn.bin", + ) + + parser.add_argument( + "sound_filename", + type=str, + help="Path to foo.wav", + ) + + return parser.parse_args() + + +class Model: + def __init__(self, args): + self.init_encoder(args) + self.init_decoder(args) + self.init_joiner(args) + + def init_encoder(self, args): + encoder_net = ncnn.Net() + encoder_net.opt.use_packing_layout = False + encoder_net.opt.use_fp16_storage = False + encoder_param = args.encoder_param_filename + encoder_model = args.encoder_bin_filename + + encoder_net.load_param(encoder_param) + encoder_net.load_model(encoder_model) + + self.encoder_net = encoder_net + + def init_decoder(self, args): + decoder_param = args.decoder_param_filename + decoder_model = args.decoder_bin_filename + + decoder_net = ncnn.Net() + decoder_net.opt.use_packing_layout = False + + decoder_net.load_param(decoder_param) + decoder_net.load_model(decoder_model) + + self.decoder_net = decoder_net + + def init_joiner(self, args): + joiner_param = args.joiner_param_filename + joiner_model = args.joiner_bin_filename + joiner_net = ncnn.Net() + joiner_net.opt.use_packing_layout = False + joiner_net.load_param(joiner_param) + joiner_net.load_model(joiner_model) + + self.joiner_net = joiner_net + + def run_encoder(self, x, states): + with self.encoder_net.create_extractor() as ex: + ex.set_num_threads(10) + ex.input("in0", ncnn.Mat(x.numpy()).clone()) + x_lens = torch.tensor([x.size(0)], dtype=torch.float32) + ex.input("in1", ncnn.Mat(x_lens.numpy()).clone()) + ex.input("in2", ncnn.Mat(states[0].numpy()).clone()) + ex.input("in3", ncnn.Mat(states[1].numpy()).clone()) + + ret, ncnn_out0 = ex.extract("out0") + assert ret == 0, ret + + ret, ncnn_out1 = ex.extract("out1") + assert ret == 0, ret + + ret, ncnn_out2 = ex.extract("out2") + assert ret == 0, ret + + 
ret, ncnn_out3 = ex.extract("out3") + assert ret == 0, ret + + encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() + encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to( + torch.int32 + ) + hx = torch.from_numpy(ncnn_out2.numpy()).clone() + cx = torch.from_numpy(ncnn_out3.numpy()).clone() + return encoder_out, encoder_out_lens, hx, cx + + def run_decoder(self, decoder_input): + assert decoder_input.dtype == torch.int32 + + with self.decoder_net.create_extractor() as ex: + ex.set_num_threads(10) + ex.input("in0", ncnn.Mat(decoder_input.numpy()).clone()) + ret, ncnn_out0 = ex.extract("out0") + assert ret == 0, ret + decoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() + return decoder_out + + def run_joiner(self, encoder_out, decoder_out): + with self.joiner_net.create_extractor() as ex: + ex.set_num_threads(10) + ex.input("in0", ncnn.Mat(encoder_out.numpy()).clone()) + ex.input("in1", ncnn.Mat(decoder_out.numpy()).clone()) + ret, ncnn_out0 = ex.extract("out0") + assert ret == 0, ret + joiner_out = torch.from_numpy(ncnn_out0.numpy()).clone() + return joiner_out + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search(model: Model, encoder_out: torch.Tensor): + assert encoder_out.ndim == 2 + T = encoder_out.size(0) + + context_size = 2 + blank_id = 0 # hard-code to 0 + hyp = [blank_id] * context_size + + decoder_input = torch.tensor(hyp, dtype=torch.int32) # (1, context_size) + + decoder_out = model.run_decoder(decoder_input).squeeze(0) + # print(decoder_out.shape) # (512,) + + for t in range(T): + encoder_out_t = encoder_out[t] + joiner_out = model.run_joiner(encoder_out_t, decoder_out) + # print(joiner_out.shape) # [500] + y = joiner_out.argmax(dim=0).tolist() + if y != blank_id: + hyp.append(y) + decoder_input = hyp[-context_size:] + decoder_input = torch.tensor(decoder_input, dtype=torch.int32) + decoder_out = model.run_decoder(decoder_input).squeeze(0) + return hyp[context_size:] + + +def main(): + args = get_args() + logging.info(vars(args)) + + model = Model(args) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model_filename) + + sound_file = args.sound_filename + + sample_rate = 16000 + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = sample_rate + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {sound_file}") + wave_samples = read_sound_files( + filenames=[sound_file], + expected_sample_rate=sample_rate, + )[0] + + logging.info("Decoding started") + features = fbank(wave_samples) + + num_encoder_layers = 12 + d_model = 512 + rnn_hidden_size = 1024 + + states = ( + torch.zeros(num_encoder_layers, d_model), + torch.zeros( + num_encoder_layers, + rnn_hidden_size, + ), + ) + + encoder_out, encoder_out_lens, hx, cx = model.run_encoder(features, states) + hyp = 
greedy_search(model, encoder_out) + logging.info(sound_file) + logging.info(sp.decode(hyp)) + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py new file mode 100755 index 000000000..da78413f7 --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py @@ -0,0 +1,353 @@ +#!/usr/bin/env python3 +# flake8: noqa +# +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +from typing import List, Optional + +import ncnn +import sentencepiece as spm +import torch +import torchaudio +from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--bpe-model-filename", + type=str, + help="Path to bpe.model", + ) + + parser.add_argument( + "--encoder-param-filename", + type=str, + help="Path to encoder.ncnn.param", + ) + + parser.add_argument( + "--encoder-bin-filename", + type=str, + help="Path to encoder.ncnn.bin", + ) + + parser.add_argument( + "--decoder-param-filename", + type=str, + help="Path to decoder.ncnn.param", + ) + + parser.add_argument( + "--decoder-bin-filename", + type=str, + help="Path to decoder.ncnn.bin", + ) + + parser.add_argument( + "--joiner-param-filename", + type=str, + help="Path to joiner.ncnn.param", + ) + + parser.add_argument( + "--joiner-bin-filename", + type=str, + help="Path to joiner.ncnn.bin", + ) + + parser.add_argument( + "sound_filename", + type=str, + help="Path to foo.wav", + ) + + return parser.parse_args() + + +class Model: + def __init__(self, args): + self.init_encoder(args) + self.init_decoder(args) + self.init_joiner(args) + + def init_encoder(self, args): + encoder_net = ncnn.Net() + encoder_net.opt.use_packing_layout = False + encoder_net.opt.use_fp16_storage = False + encoder_param = args.encoder_param_filename + encoder_model = args.encoder_bin_filename + + encoder_net.load_param(encoder_param) + encoder_net.load_model(encoder_model) + + self.encoder_net = encoder_net + + def init_decoder(self, args): + decoder_param = args.decoder_param_filename + decoder_model = args.decoder_bin_filename + + decoder_net = ncnn.Net() + decoder_net.opt.use_packing_layout = False + + decoder_net.load_param(decoder_param) + decoder_net.load_model(decoder_model) + + self.decoder_net = decoder_net + + def init_joiner(self, args): + joiner_param = args.joiner_param_filename + joiner_model = args.joiner_bin_filename + joiner_net = ncnn.Net() + joiner_net.opt.use_packing_layout = False + joiner_net.load_param(joiner_param) + joiner_net.load_model(joiner_model) + + self.joiner_net = joiner_net + + 
def run_encoder(self, x, states): + with self.encoder_net.create_extractor() as ex: + # ex.set_num_threads(10) + ex.input("in0", ncnn.Mat(x.numpy()).clone()) + x_lens = torch.tensor([x.size(0)], dtype=torch.float32) + ex.input("in1", ncnn.Mat(x_lens.numpy()).clone()) + ex.input("in2", ncnn.Mat(states[0].numpy()).clone()) + ex.input("in3", ncnn.Mat(states[1].numpy()).clone()) + + ret, ncnn_out0 = ex.extract("out0") + assert ret == 0, ret + + ret, ncnn_out1 = ex.extract("out1") + assert ret == 0, ret + + ret, ncnn_out2 = ex.extract("out2") + assert ret == 0, ret + + ret, ncnn_out3 = ex.extract("out3") + assert ret == 0, ret + + encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() + encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to( + torch.int32 + ) + hx = torch.from_numpy(ncnn_out2.numpy()).clone() + cx = torch.from_numpy(ncnn_out3.numpy()).clone() + return encoder_out, encoder_out_lens, hx, cx + + def run_decoder(self, decoder_input): + assert decoder_input.dtype == torch.int32 + + with self.decoder_net.create_extractor() as ex: + # ex.set_num_threads(10) + ex.input("in0", ncnn.Mat(decoder_input.numpy()).clone()) + ret, ncnn_out0 = ex.extract("out0") + assert ret == 0, ret + decoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() + return decoder_out + + def run_joiner(self, encoder_out, decoder_out): + with self.joiner_net.create_extractor() as ex: + # ex.set_num_threads(10) + ex.input("in0", ncnn.Mat(encoder_out.numpy()).clone()) + ex.input("in1", ncnn.Mat(decoder_out.numpy()).clone()) + ret, ncnn_out0 = ex.extract("out0") + assert ret == 0, ret + joiner_out = torch.from_numpy(ncnn_out0.numpy()).clone() + return joiner_out + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +def create_streaming_feature_extractor() -> OnlineFeature: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. 
+ """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + return OnlineFbank(opts) + + +def greedy_search( + model: Model, + encoder_out: torch.Tensor, + decoder_out: Optional[torch.Tensor] = None, + hyp: Optional[List[int]] = None, +): + assert encoder_out.ndim == 1 + context_size = 2 + blank_id = 0 + + if decoder_out is None: + assert hyp is None, hyp + hyp = [blank_id] * context_size + decoder_input = torch.tensor( + hyp, dtype=torch.int32 + ) # (1, context_size) + decoder_out = model.run_decoder(decoder_input).squeeze(0) + + else: + assert decoder_out.ndim == 1 + assert hyp is not None, hyp + + joiner_out = model.run_joiner(encoder_out, decoder_out) + y = joiner_out.argmax(dim=0).tolist() + if y != blank_id: + hyp.append(y) + decoder_input = hyp[-context_size:] + decoder_input = torch.tensor(decoder_input, dtype=torch.int32) + decoder_out = model.run_decoder(decoder_input).squeeze(0) + + return hyp, decoder_out + + +def main(): + args = get_args() + logging.info(vars(args)) + + model = Model(args) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model_filename) + + sound_file = args.sound_filename + + sample_rate = 16000 + + logging.info("Constructing Fbank computer") + online_fbank = create_streaming_feature_extractor() + + logging.info(f"Reading sound files: {sound_file}") + wave_samples = read_sound_files( + filenames=[sound_file], + expected_sample_rate=sample_rate, + )[0] + logging.info(wave_samples.shape) + + num_encoder_layers = 12 + batch_size = 1 + d_model = 512 + rnn_hidden_size = 1024 + + states = ( + torch.zeros(num_encoder_layers, batch_size, d_model), + torch.zeros( + num_encoder_layers, + batch_size, + rnn_hidden_size, + ), + ) + + hyp = None + decoder_out = None + + num_processed_frames = 0 + segment = 9 + offset = 4 + + chunk = 3200 # 0.2 second + + start = 0 + while start < wave_samples.numel(): + end = min(start + chunk, wave_samples.numel()) + samples = wave_samples[start:end] + start += chunk + + online_fbank.accept_waveform( + sampling_rate=sample_rate, + waveform=samples, + ) + while online_fbank.num_frames_ready - num_processed_frames >= segment: + frames = [] + for i in range(segment): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + num_processed_frames += offset + frames = torch.cat(frames, dim=0) + encoder_out, encoder_out_lens, hx, cx = model.run_encoder( + frames, states + ) + states = (hx, cx) + hyp, decoder_out = greedy_search( + model, encoder_out.squeeze(0), decoder_out, hyp + ) + online_fbank.accept_waveform( + sampling_rate=sample_rate, waveform=torch.zeros(8000, dtype=torch.int32) + ) + + online_fbank.input_finished() + while online_fbank.num_frames_ready - num_processed_frames >= segment: + frames = [] + for i in range(segment): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + num_processed_frames += offset + frames = torch.cat(frames, dim=0) + encoder_out, encoder_out_lens, hx, cx = model.run_encoder( + frames, states + ) + states = (hx, cx) + hyp, decoder_out = greedy_search( + model, encoder_out.squeeze(0), decoder_out, hyp + ) + + context_size = 2 + + logging.info(sound_file) + logging.info(sp.decode(hyp[context_size:])) + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git 
a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py index b3e50c52b..9eed2dfcb 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py @@ -406,6 +406,8 @@ def get_params() -> AttributeDict: "decoder_dim": 512, # parameters for joiner "joiner_dim": 512, + # True to generate a model that can be exported via PNNX + "is_pnnx": False, # parameters for Noam "model_warm_step": 3000, # arg given to model, not for lrate "env_info": get_env_info(), @@ -424,6 +426,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: dim_feedforward=params.dim_feedforward, num_encoder_layers=params.num_encoder_layers, aux_layer_period=params.aux_layer_period, + is_pnnx=params.is_pnnx, ) return encoder diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py index 428f35796..43f5d409c 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py @@ -30,6 +30,7 @@ from typing import List import torch import torch.nn as nn from scaling import ( + BasicNorm, ScaledConv1d, ScaledConv2d, ScaledEmbedding, @@ -38,6 +39,29 @@ from scaling import ( ) +class NonScaledNorm(nn.Module): + """See BasicNorm for doc""" + + def __init__( + self, + num_channels: int, + eps_exp: float, + channel_dim: int = -1, # CAUTION: see documentation. + ): + super().__init__() + self.num_channels = num_channels + self.channel_dim = channel_dim + self.eps_exp = eps_exp + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if not torch.jit.is_tracing(): + assert x.shape[self.channel_dim] == self.num_channels + scales = ( + torch.mean(x * x, dim=self.channel_dim, keepdim=True) + self.eps_exp + ).pow(-0.5) + return x * scales + + def scaled_linear_to_linear(scaled_linear: ScaledLinear) -> nn.Linear: """Convert an instance of ScaledLinear to nn.Linear. @@ -174,6 +198,16 @@ def scaled_embedding_to_embedding( return embedding +def convert_basic_norm(basic_norm: BasicNorm) -> NonScaledNorm: + assert isinstance(basic_norm, BasicNorm), type(BasicNorm) + norm = NonScaledNorm( + num_channels=basic_norm.num_channels, + eps_exp=basic_norm.eps.data.exp().item(), + channel_dim=basic_norm.channel_dim, + ) + return norm + + def scaled_lstm_to_lstm(scaled_lstm: ScaledLSTM) -> nn.LSTM: """Convert an instance of ScaledLSTM to nn.LSTM. @@ -256,6 +290,8 @@ def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False): d[name] = scaled_conv2d_to_conv2d(m) elif isinstance(m, ScaledEmbedding): d[name] = scaled_embedding_to_embedding(m) + elif isinstance(m, BasicNorm): + d[name] = convert_basic_norm(m) elif isinstance(m, ScaledLSTM): d[name] = scaled_lstm_to_lstm(m)
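The ``convert_basic_norm()`` helper added above is needed because ``BasicNorm`` stores ``eps`` in log space as a parameter and calls ``.exp()`` on it in ``forward()`` (hence ``basic_norm.eps.data.exp().item()``), whereas ``NonScaledNorm`` receives ``exp(eps)`` as a plain float, so ``torch.jit.trace()`` and pnnx only see ordinary tensor ops. The following is a minimal sketch of that equivalence, assuming simplified, hypothetical ``TinyBasicNorm``/``TinyNonScaledNorm`` stand-ins for the real classes in ``scaling.py`` and ``scaling_converter.py``:

# Illustration only -- not part of the patch. TinyBasicNorm is a simplified
# stand-in for icefall's BasicNorm; convert() mirrors convert_basic_norm().
import torch
import torch.nn as nn


class TinyBasicNorm(nn.Module):
    """eps is learned in log space and exponentiated in forward()."""

    def __init__(self, num_channels: int, eps: float = 0.25):
        super().__init__()
        self.num_channels = num_channels
        self.eps = nn.Parameter(torch.tensor(eps).log().detach())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scales = (
            torch.mean(x * x, dim=-1, keepdim=True) + self.eps.exp()
        ).pow(-0.5)
        return x * scales


class TinyNonScaledNorm(nn.Module):
    """Same computation, but exp(eps) is baked in as a plain float."""

    def __init__(self, num_channels: int, eps_exp: float):
        super().__init__()
        self.num_channels = num_channels
        self.eps_exp = eps_exp

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scales = (
            torch.mean(x * x, dim=-1, keepdim=True) + self.eps_exp
        ).pow(-0.5)
        return x * scales


def convert(norm: TinyBasicNorm) -> TinyNonScaledNorm:
    # Same idea as convert_basic_norm(): read exp(eps) out of the parameter.
    return TinyNonScaledNorm(norm.num_channels, norm.eps.data.exp().item())


if __name__ == "__main__":
    torch.manual_seed(0)
    basic = TinyBasicNorm(num_channels=8)
    converted = convert(basic)
    x = torch.randn(2, 10, 8)
    assert torch.allclose(basic(x), converted(x), atol=1e-6)
    # The converted module traces without touching a Parameter.
    traced = torch.jit.trace(converted, x)
    assert torch.allclose(traced(x), converted(x), atol=1e-6)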