Merge branch 'k2-fsa:master' into dev_zipformer_cn

This commit is contained in:
zr_jin 2023-09-22 19:18:31 +08:00 committed by GitHub
commit 023f6e05d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
156 changed files with 8961 additions and 527 deletions

View File

@ -29,6 +29,9 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" ==
ls -lh data/fbank ls -lh data/fbank
ls -lh pruned_transducer_stateless2/exp ls -lh pruned_transducer_stateless2/exp
ln -s data/fbank/cuts_DEV.jsonl.gz data/fbank/gigaspeech_cuts_DEV.jsonl.gz
ln -s data/fbank/cuts_TEST.jsonl.gz data/fbank/gigaspeech_cuts_TEST.jsonl.gz
log "Decoding dev and test" log "Decoding dev and test"
# use a small value for decoding with CPU # use a small value for decoding with CPU

View File

@ -0,0 +1,51 @@
#!/usr/bin/env bash
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
cd egs/multi_zh-hans/ASR
repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2/
log "Downloading pre-trained model from $repo_url"
git lfs install
git clone $repo_url
repo=$(basename $repo_url)
log "Display test files"
tree $repo/
ls -lh $repo/test_wavs/*.wav
pushd $repo/exp
ln -s epoch-20.pt epoch-99.pt
popd
ls -lh $repo/exp/*.pt
./zipformer/pretrained.py \
--checkpoint $repo/exp/epoch-99.pt \
--tokens $repo/data/lang_bpe_2000/tokens.txt \
--method greedy_search \
$repo/test_wavs/DEV_T0000000000.wav \
$repo/test_wavs/DEV_T0000000001.wav \
$repo/test_wavs/DEV_T0000000002.wav
for method in modified_beam_search fast_beam_search; do
log "$method"
./zipformer/pretrained.py \
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/epoch-99.pt \
--tokens $repo/data/lang_bpe_2000/tokens.txt \
$repo/test_wavs/DEV_T0000000000.wav \
$repo/test_wavs/DEV_T0000000001.wav \
$repo/test_wavs/DEV_T0000000002.wav
done

View File

@ -45,7 +45,7 @@ jobs:
strategy: strategy:
matrix: matrix:
os: [ubuntu-latest] os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9] python-version: [3.8]
fail-fast: false fail-fast: false

View File

@ -44,7 +44,7 @@ jobs:
strategy: strategy:
matrix: matrix:
os: [ubuntu-latest] os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9] python-version: [3.8]
fail-fast: false fail-fast: false

View File

@ -44,7 +44,7 @@ jobs:
strategy: strategy:
matrix: matrix:
os: [ubuntu-latest] os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9] python-version: [3.8]
fail-fast: false fail-fast: false

View File

@ -44,7 +44,7 @@ jobs:
strategy: strategy:
matrix: matrix:
os: [ubuntu-latest] os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9] python-version: [3.8]
fail-fast: false fail-fast: false

View File

@ -44,7 +44,7 @@ jobs:
strategy: strategy:
matrix: matrix:
os: [ubuntu-latest] os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9] python-version: [3.8]
fail-fast: false fail-fast: false

View File

@ -44,7 +44,7 @@ jobs:
strategy: strategy:
matrix: matrix:
os: [ubuntu-latest] os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9] python-version: [3.8]
fail-fast: false fail-fast: false

View File

@ -44,7 +44,7 @@ jobs:
strategy: strategy:
matrix: matrix:
os: [ubuntu-latest] os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9] python-version: [3.8]
fail-fast: false fail-fast: false

View File

@ -44,7 +44,7 @@ jobs:
strategy: strategy:
matrix: matrix:
os: [ubuntu-latest] os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9] python-version: [3.8]
fail-fast: false fail-fast: false

View File

@ -0,0 +1,84 @@
# Copyright 2023 Xiaomi Corp. (author: Zengrui Jin)
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: run-multi-zh_hans-zipformer
on:
push:
branches:
- master
pull_request:
types: [labeled]
concurrency:
group: run_multi-zh_hans_zipformer-${{ github.ref }}
cancel-in-progress: true
jobs:
run_multi-zh_hans_zipformer:
if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'multi-zh_hans'
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest]
python-version: [3.8]
fail-fast: false
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'
cache-dependency-path: '**/requirements-ci.txt'
- name: Install Python dependencies
run: |
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
pip uninstall -y protobuf
pip install --no-binary protobuf protobuf==3.20.*
- name: Cache kaldifeat
id: my-cache
uses: actions/cache@v2
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}-2023-05-22
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/install-kaldifeat.sh
- name: Inference with pre-trained model
shell: bash
env:
GITHUB_EVENT_NAME: ${{ github.event_name }}
GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
run: |
sudo apt-get -qq install git-lfs tree
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
.github/scripts/run-multi-zh_hans-zipformer.sh

View File

@ -34,7 +34,7 @@ jobs:
strategy: strategy:
matrix: matrix:
os: [ubuntu-latest] os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9] python-version: [3.8]
fail-fast: false fail-fast: false

View File

@ -43,7 +43,7 @@ jobs:
strategy: strategy:
matrix: matrix:
os: [ubuntu-latest] os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9] python-version: [3.8]
fail-fast: false fail-fast: false

View File

@ -43,7 +43,7 @@ jobs:
strategy: strategy:
matrix: matrix:
os: [ubuntu-latest] os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9] python-version: [3.8]
fail-fast: false fail-fast: false

View File

@ -34,7 +34,7 @@ jobs:
strategy: strategy:
matrix: matrix:
os: [ubuntu-latest] os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9] python-version: [3.8]
fail-fast: false fail-fast: false

View File

@ -34,7 +34,7 @@ jobs:
strategy: strategy:
matrix: matrix:
os: [ubuntu-latest] os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9] python-version: [3.8]
fail-fast: false fail-fast: false

View File

@ -43,7 +43,7 @@ jobs:
strategy: strategy:
matrix: matrix:
os: [ubuntu-latest] os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9] python-version: [3.8]
fail-fast: false fail-fast: false

View File

@ -34,7 +34,7 @@ jobs:
strategy: strategy:
matrix: matrix:
os: [ubuntu-latest] os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9] python-version: [3.8]
fail-fast: false fail-fast: false

View File

@ -338,7 +338,7 @@ We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder
#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss #### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss
The best results for Chinese CER(%) and English WER(%) respectivly (zh: Chinese, en: English): The best results for Chinese CER(%) and English WER(%) respectively (zh: Chinese, en: English):
|decoding-method | dev | dev_zh | dev_en | test | test_zh | test_en | |decoding-method | dev | dev_zh | dev_en | test | test_zh | test_en |
|--|--|--|--|--|--|--| |--|--|--|--|--|--|--|
|greedy_search| 7.30 | 6.48 | 19.19 |7.39| 6.66 | 19.13| |greedy_search| 7.30 | 6.48 | 19.19 |7.39| 6.66 | 19.13|

View File

@ -95,4 +95,7 @@ rst_epilog = """
.. _k2: https://github.com/k2-fsa/k2 .. _k2: https://github.com/k2-fsa/k2
.. _lhotse: https://github.com/lhotse-speech/lhotse .. _lhotse: https://github.com/lhotse-speech/lhotse
.. _yesno: https://www.openslr.org/1/ .. _yesno: https://www.openslr.org/1/
.. _Next-gen Kaldi: https://github.com/k2-fsa
.. _Kaldi: https://github.com/kaldi-asr/kaldi
.. _lilcom: https://github.com/danpovey/lilcom
""" """

View File

@ -71,9 +71,12 @@ As the initial step, let's download the pre-trained model.
.. code-block:: bash .. code-block:: bash
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
$ pushd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp $ cd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
$ git lfs pull --include "pretrained.pt" $ git lfs pull --include "pretrained.pt"
$ ln -s pretrained.pt epoch-99.pt # create a symbolic link so that the checkpoint can be loaded $ ln -s pretrained.pt epoch-99.pt # create a symbolic link so that the checkpoint can be loaded
$ cd ../data/lang_bpe_500
$ git lfs pull --include bpe.model
$ cd ../../..
To test the model, let's have a look at the decoding results **without** using LM. This can be done via the following command: To test the model, let's have a look at the decoding results **without** using LM. This can be done via the following command:

View File

@ -34,9 +34,12 @@ As the initial step, let's download the pre-trained model.
.. code-block:: bash .. code-block:: bash
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
$ pushd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp $ cd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
$ git lfs pull --include "pretrained.pt" $ git lfs pull --include "pretrained.pt"
$ ln -s pretrained.pt epoch-99.pt # create a symbolic link so that the checkpoint can be loaded $ ln -s pretrained.pt epoch-99.pt # create a symbolic link so that the checkpoint can be loaded
$ cd ../data/lang_bpe_500
$ git lfs pull --include bpe.model
$ cd ../../..
As usual, we first test the model's performance without external LM. This can be done via the following command: As usual, we first test the model's performance without external LM. This can be done via the following command:

View File

@ -32,9 +32,12 @@ As the initial step, let's download the pre-trained model.
.. code-block:: bash .. code-block:: bash
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
$ pushd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp $ cd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
$ git lfs pull --include "pretrained.pt" $ git lfs pull --include "pretrained.pt"
$ ln -s pretrained.pt epoch-99.pt # create a symbolic link so that the checkpoint can be loaded $ ln -s pretrained.pt epoch-99.pt # create a symbolic link so that the checkpoint can be loaded
$ cd ../data/lang_bpe_500
$ git lfs pull --include bpe.model
$ cd ../../..
To test the model, let's have a look at the decoding results without using LM. This can be done via the following command: To test the model, let's have a look at the decoding results without using LM. This can be done via the following command:

View File

@ -0,0 +1,180 @@
.. _dummies_tutorial_data_preparation:
Data Preparation
================
After :ref:`dummies_tutorial_environment_setup`, we can start preparing the
data for training and decoding.
The first step is to prepare the data for training. We have already provided
`prepare.sh <https://github.com/k2-fsa/icefall/blob/master/egs/yesno/ASR/prepare.sh>`_
that would prepare everything required for training.
.. code-block::
cd /tmp/icefall
export PYTHONPATH=/tmp/icefall:$PYTHONPATH
cd egs/yesno/ASR
./prepare.sh
Note that in each recipe from `icefall`_, there exists a file ``prepare.sh``,
which you should run before you run anything else.
That is all you need for data preparation.
For the more curious
--------------------
If you are wondering how to prepare your own dataset, please refer to the following
URLs for more details:
- `<https://github.com/lhotse-speech/lhotse/tree/master/lhotse/recipes>`_
It contains recipes for a variety of dataset. If you want to add your own
dataset, please read recipes in this folder first.
- `<https://github.com/lhotse-speech/lhotse/blob/master/lhotse/recipes/yesno.py>`_
The `yesno`_ recipe in `lhotse`_.
If you already have a `Kaldi`_ dataset directory, which contains files like
``wav.scp``, ``feats.scp``, then you can refer to `<https://lhotse.readthedocs.io/en/latest/kaldi.html#example>`_.
A quick look to the generated files
-----------------------------------
``./prepare.sh`` puts generated files into two directories:
- ``download``
- ``data``
download
^^^^^^^^
The ``download`` directory contains downloaded dataset files:
.. code-block:: bas
tree -L 1 ./download/
./download/
|-- waves_yesno
`-- waves_yesno.tar.gz
.. hint::
Please refer to `<https://github.com/lhotse-speech/lhotse/blob/master/lhotse/recipes/yesno.py#L41>`_
for how the data is downloaded and extracted.
data
^^^^
.. code-block:: bash
tree ./data/
./data/
|-- fbank
| |-- yesno_cuts_test.jsonl.gz
| |-- yesno_cuts_train.jsonl.gz
| |-- yesno_feats_test.lca
| `-- yesno_feats_train.lca
|-- lang_phone
| |-- HLG.pt
| |-- L.pt
| |-- L_disambig.pt
| |-- Linv.pt
| |-- lexicon.txt
| |-- lexicon_disambig.txt
| |-- tokens.txt
| `-- words.txt
|-- lm
| |-- G.arpa
| `-- G.fst.txt
`-- manifests
|-- yesno_recordings_test.jsonl.gz
|-- yesno_recordings_train.jsonl.gz
|-- yesno_supervisions_test.jsonl.gz
`-- yesno_supervisions_train.jsonl.gz
4 directories, 18 files
**data/manifests**:
This directory contains manifests. They are used to generate files in
``data/fbank``.
To give you an idea of what it contains, we examine the first few lines of
the manifests related to the ``train`` dataset.
.. code-block:: bash
cd data/manifests
gunzip -c yesno_recordings_train.jsonl.gz | head -n 3
The output is given below:
.. code-block:: bash
{"id": "0_0_0_0_1_1_1_1", "sources": [{"type": "file", "channels": [0], "source": "/tmp/icefall/egs/yesno/ASR/download/waves_yesno/0_0_0_0_1_1_1_1.wav"}], "sampling_rate": 8000, "num_samples": 50800, "duration": 6.35, "channel_ids": [0]}
{"id": "0_0_0_1_0_1_1_0", "sources": [{"type": "file", "channels": [0], "source": "/tmp/icefall/egs/yesno/ASR/download/waves_yesno/0_0_0_1_0_1_1_0.wav"}], "sampling_rate": 8000, "num_samples": 48880, "duration": 6.11, "channel_ids": [0]}
{"id": "0_0_1_0_0_1_1_0", "sources": [{"type": "file", "channels": [0], "source": "/tmp/icefall/egs/yesno/ASR/download/waves_yesno/0_0_1_0_0_1_1_0.wav"}], "sampling_rate": 8000, "num_samples": 48160, "duration": 6.02, "channel_ids": [0]}
Please refer to `<https://github.com/lhotse-speech/lhotse/blob/master/lhotse/audio.py#L300>`_
for the meaning of each field per line.
.. code-block:: bash
gunzip -c yesno_supervisions_train.jsonl.gz | head -n 3
The output is given below:
.. code-block:: bash
{"id": "0_0_0_0_1_1_1_1", "recording_id": "0_0_0_0_1_1_1_1", "start": 0.0, "duration": 6.35, "channel": 0, "text": "NO NO NO NO YES YES YES YES", "language": "Hebrew"}
{"id": "0_0_0_1_0_1_1_0", "recording_id": "0_0_0_1_0_1_1_0", "start": 0.0, "duration": 6.11, "channel": 0, "text": "NO NO NO YES NO YES YES NO", "language": "Hebrew"}
{"id": "0_0_1_0_0_1_1_0", "recording_id": "0_0_1_0_0_1_1_0", "start": 0.0, "duration": 6.02, "channel": 0, "text": "NO NO YES NO NO YES YES NO", "language": "Hebrew"}
Please refer to `<https://github.com/lhotse-speech/lhotse/blob/master/lhotse/supervision.py#L510>`_
for the meaning of each field per line.
**data/fbank**:
This directory contains everything from ``data/manifests``. Furthermore, it also contains features
for training.
``data/fbank/yesno_feats_train.lca`` contains the features for the train dataset.
Features are compressed using `lilcom`_.
``data/fbank/yesno_cuts_train.jsonl.gz`` stores the `CutSet <https://github.com/lhotse-speech/lhotse/blob/master/lhotse/cut/set.py#L72>`_,
which stores `RecordingSet <https://github.com/lhotse-speech/lhotse/blob/master/lhotse/audio.py#L928>`_,
`SupervisionSet <https://github.com/lhotse-speech/lhotse/blob/master/lhotse/supervision.py#L510>`_,
and `FeatureSet <https://github.com/lhotse-speech/lhotse/blob/master/lhotse/features/base.py#L593>`_.
To give you an idea about what it looks like, we can run the following command:
.. code-block:: bash
cd data/fbank
gunzip -c yesno_cuts_train.jsonl.gz | head -n 3
The output is given below:
.. code-block:: bash
{"id": "0_0_0_0_1_1_1_1-0", "start": 0, "duration": 6.35, "channel": 0, "supervisions": [{"id": "0_0_0_0_1_1_1_1", "recording_id": "0_0_0_0_1_1_1_1", "start": 0.0, "duration": 6.35, "channel": 0, "text": "NO NO NO NO YES YES YES YES", "language": "Hebrew"}], "features": {"type": "kaldi-fbank", "num_frames": 635, "num_features": 23, "frame_shift": 0.01, "sampling_rate": 8000, "start": 0, "duration": 6.35, "storage_type": "lilcom_chunky", "storage_path": "data/fbank/yesno_feats_train.lca", "storage_key": "0,13000,3570", "channels": 0}, "recording": {"id": "0_0_0_0_1_1_1_1", "sources": [{"type": "file", "channels": [0], "source": "/tmp/icefall/egs/yesno/ASR/download/waves_yesno/0_0_0_0_1_1_1_1.wav"}], "sampling_rate": 8000, "num_samples": 50800, "duration": 6.35, "channel_ids": [0]}, "type": "MonoCut"}
{"id": "0_0_0_1_0_1_1_0-1", "start": 0, "duration": 6.11, "channel": 0, "supervisions": [{"id": "0_0_0_1_0_1_1_0", "recording_id": "0_0_0_1_0_1_1_0", "start": 0.0, "duration": 6.11, "channel": 0, "text": "NO NO NO YES NO YES YES NO", "language": "Hebrew"}], "features": {"type": "kaldi-fbank", "num_frames": 611, "num_features": 23, "frame_shift": 0.01, "sampling_rate": 8000, "start": 0, "duration": 6.11, "storage_type": "lilcom_chunky", "storage_path": "data/fbank/yesno_feats_train.lca", "storage_key": "16570,12964,2929", "channels": 0}, "recording": {"id": "0_0_0_1_0_1_1_0", "sources": [{"type": "file", "channels": [0], "source": "/tmp/icefall/egs/yesno/ASR/download/waves_yesno/0_0_0_1_0_1_1_0.wav"}], "sampling_rate": 8000, "num_samples": 48880, "duration": 6.11, "channel_ids": [0]}, "type": "MonoCut"}
{"id": "0_0_1_0_0_1_1_0-2", "start": 0, "duration": 6.02, "channel": 0, "supervisions": [{"id": "0_0_1_0_0_1_1_0", "recording_id": "0_0_1_0_0_1_1_0", "start": 0.0, "duration": 6.02, "channel": 0, "text": "NO NO YES NO NO YES YES NO", "language": "Hebrew"}], "features": {"type": "kaldi-fbank", "num_frames": 602, "num_features": 23, "frame_shift": 0.01, "sampling_rate": 8000, "start": 0, "duration": 6.02, "storage_type": "lilcom_chunky", "storage_path": "data/fbank/yesno_feats_train.lca", "storage_key": "32463,12936,2696", "channels": 0}, "recording": {"id": "0_0_1_0_0_1_1_0", "sources": [{"type": "file", "channels": [0], "source": "/tmp/icefall/egs/yesno/ASR/download/waves_yesno/0_0_1_0_0_1_1_0.wav"}], "sampling_rate": 8000, "num_samples": 48160, "duration": 6.02, "channel_ids": [0]}, "type": "MonoCut"}
Note that ``yesno_cuts_train.jsonl.gz`` only stores the information about how to read the features.
The actual features are stored separately in ``data/fbank/yesno_feats_train.lca``.
**data/lang**:
This directory contains the lexicon.
**data/lm**:
This directory contains language models.

View File

@ -0,0 +1,39 @@
.. _dummies_tutorial_decoding:
Decoding
========
After :ref:`dummies_tutorial_training`, we can start decoding.
The command to start the decoding is quite simple:
.. code-block:: bash
cd /tmp/icefall
export PYTHONPATH=/tmp/icefall:$PYTHONPATH
cd egs/yesno/ASR
# We use CPU for decoding by setting the following environment variable
export CUDA_VISIBLE_DEVICES=""
./tdnn/decode.py
The output logs are given below:
.. literalinclude:: ./code/decoding-yesno.txt
For the more curious
--------------------
.. code-block:: bash
./tdnn/decode.py --help
will print the usage information about ``./tdnn/decode.py``. For instance, you
can specify:
- ``--epoch`` to use which checkpoint for decoding
- ``--avg`` to select how many checkpoints to use for model averaging
You usually try different combinations of ``--epoch`` and ``--avg`` and select
one that leads to the lowest WER (`Word Error Rate <https://en.wikipedia.org/wiki/Word_error_rate>`_).

View File

@ -0,0 +1,121 @@
.. _dummies_tutorial_environment_setup:
Environment setup
=================
We will create an environment for `Next-gen Kaldi`_ that runs on ``CPU``
in this tutorial.
.. note::
Since the `yesno`_ dataset used in this tutorial is very tiny, training on
``CPU`` works very well for it.
If your dataset is very large, e.g., hundreds or thousands of hours of
training data, please follow :ref:`install icefall` to install `icefall`_
that works with ``GPU``.
Create a virtual environment
----------------------------
.. code-block:: bash
virtualenv -p python3 /tmp/icefall_env
The above command creates a virtual environment in the directory ``/tmp/icefall_env``.
You can select any directory you want.
The output of the above command is given below:
.. code-block:: bash
Already using interpreter /usr/bin/python3
Using base prefix '/usr'
New python executable in /tmp/icefall_env/bin/python3
Also creating executable in /tmp/icefall_env/bin/python
Installing setuptools, pkg_resources, pip, wheel...done.
Now we can activate the environment using:
.. code-block:: bash
source /tmp/icefall_env/bin/activate
Install dependencies
--------------------
.. warning::
Remeber to activate your virtual environment before you continue!
After activating the virtual environment, we can use the following command
to install dependencies of `icefall`_:
.. hint::
Remeber that we will run this tutorial on ``CPU``, so we install
dependencies required only by running on ``CPU``.
.. code-block:: bash
# Caution: Installation order matters!
# We use torch 2.0.0 and torchaduio 2.0.0 in this tutorial.
# Other versions should also work.
pip install torch==2.0.0+cpu torchaudio==2.0.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
# If you are using macOS or Windows, please use the following command to install torch and torchaudio
# pip install torch==2.0.0 torchaudio==2.0.0 -f https://download.pytorch.org/whl/torch_stable.html
# Now install k2
# Please refer to https://k2-fsa.github.io/k2/installation/from_wheels.html#linux-cpu-example
pip install k2==1.24.3.dev20230726+cpu.torch2.0.0 -f https://k2-fsa.github.io/k2/cpu.html
# Install the latest version of lhotse
pip install git+https://github.com/lhotse-speech/lhotse
Install icefall
---------------
We will put the source code of `icefall`_ into the directory ``/tmp``
You can select any directory you want.
.. code-block:: bash
cd /tmp
git clone https://github.com/k2-fsa/icefall
cd icefall
pip install -r ./requirements.txt
.. code-block:: bash
# Anytime we want to use icefall, we have to set the following
# environment variable
export PYTHONPATH=/tmp/icefall:$PYTHONPATH
.. hint::
If you get the following error during this tutorial:
.. code-block:: bash
ModuleNotFoundError: No module named 'icefall'
please set the above environment variable to fix it.
Congratulations! You have installed `icefall`_ successfully.
For the more curious
--------------------
`icefall`_ contains a collection of Python scripts and you don't need to
use ``python3 setup.py install`` or ``pip install icefall`` to install it.
All you need to do is to download the code and set the environment variable
``PYTHONPATH``.

View File

@ -0,0 +1,34 @@
Icefall for dummies tutorial
============================
This tutorial walks you step by step about how to create a simple
ASR (`Automatic Speech Recognition <https://en.wikipedia.org/wiki/Speech_recognition>`_)
system with `Next-gen Kaldi`_.
We use the `yesno`_ dataset for demonstration. We select it out of two reasons:
- It is quite tiny, containing only about 12 minutes of data
- The training can be finished within 20 seconds on ``CPU``.
That also means you don't need a ``GPU`` to run this tutorial.
Let's get started!
Please follow items below **sequentially**.
.. note::
The :ref:`dummies_tutorial_data_preparation` runs only on Linux and on macOS.
All other parts run on Linux, macOS, and Windows.
Help from the community is appreciated to port the :ref:`dummies_tutorial_data_preparation`
to Windows.
.. toctree::
:maxdepth: 2
./environment-setup.rst
./data-preparation.rst
./training.rst
./decoding.rst
./model-export.rst

View File

@ -0,0 +1,310 @@
Model Export
============
There are three ways to export a pre-trained model.
- Export the model parameters via `model.state_dict() <https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.state_dict>`_
- Export via `torchscript <https://pytorch.org/docs/stable/jit.html>`_: either `torch.jit.script() <https://pytorch.org/docs/stable/generated/torch.jit.script.html#torch.jit.script>`_ or `torch.jit.trace() <https://pytorch.org/docs/stable/generated/torch.jit.trace.html>`_
- Export to `ONNX`_ via `torch.onnx.export() <https://pytorch.org/docs/stable/onnx.html>`_
Each method is explained below in detail.
Export the model parameters via model.state_dict()
---------------------------------------------------
The command for this kind of export is
.. code-block:: bash
cd /tmp/icefall
export PYTHONPATH=/tmp/icefall:$PYTHONPATH
cd egs/yesno/ASR
# assume that "--epoch 14 --avg 2" produces the lowest WER.
./tdnn/export.py --epoch 14 --avg 2
The output logs are given below:
.. code-block:: bash
2023-08-16 20:42:03,912 INFO [export.py:76] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lr': 0.01, 'feature_dim': 23, 'weight_decay': 1e-06, 'start_epoch': 0, 'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 10, 'reset_interval': 20, 'valid_interval': 10, 'beam_size': 10, 'reduction': 'sum', 'use_double_scores': True, 'epoch': 14, 'avg': 2, 'jit': False}
2023-08-16 20:42:03,913 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt
2023-08-16 20:42:03,950 INFO [export.py:93] averaging ['tdnn/exp/epoch-13.pt', 'tdnn/exp/epoch-14.pt']
2023-08-16 20:42:03,971 INFO [export.py:106] Not using torch.jit.script
2023-08-16 20:42:03,974 INFO [export.py:111] Saved to tdnn/exp/pretrained.pt
We can see from the logs that the exported model is saved to the file ``tdnn/exp/pretrained.pt``.
To give you an idea of what ``tdnn/exp/pretrained.pt`` contains, we can use the following command:
.. code-block:: python3
>>> import torch
>>> m = torch.load("tdnn/exp/pretrained.pt")
>>> list(m.keys())
['model']
>>> list(m["model"].keys())
['tdnn.0.weight', 'tdnn.0.bias', 'tdnn.2.running_mean', 'tdnn.2.running_var', 'tdnn.2.num_batches_tracked', 'tdnn.3.weight', 'tdnn.3.bias', 'tdnn.5.running_mean', 'tdnn.5.running_var', 'tdnn.5.num_batches_tracked', 'tdnn.6.weight', 'tdnn.6.bias', 'tdnn.8.running_mean', 'tdnn.8.running_var', 'tdnn.8.num_batches_tracked', 'output_linear.weight', 'output_linear.bias']
We can use ``tdnn/exp/pretrained.pt`` in the following way with ``./tdnn/decode.py``:
.. code-block:: bash
cd tdnn/exp
ln -s pretrained.pt epoch-99.pt
cd ../..
./tdnn/decode.py --epoch 99 --avg 1
The output logs of the above command are given below:
.. code-block:: bash
2023-08-16 20:45:48,089 INFO [decode.py:262] Decoding started
2023-08-16 20:45:48,090 INFO [decode.py:263] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'feature_dim': 23, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'epoch': 99, 'avg': 1, 'export': False, 'feature_dir': PosixPath('data/fbank'), 'max_duration': 30.0, 'bucketing_sampler': False, 'num_buckets': 10, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': False, 'return_cuts': True, 'num_workers': 2, 'env_info': {'k2-version': '1.24.3', 'k2-build-type': 'Release', 'k2-with-cuda': False, 'k2-git-sha1': 'ad79f1c699c684de9785ed6ca5edb805a41f78c3', 'k2-git-date': 'Wed Jul 26 09:30:42 2023', 'lhotse-version': '1.16.0.dev+git.aa073f6.clean', 'torch-version': '2.0.0', 'torch-cuda-available': False, 'torch-cuda-version': None, 'python-version': '3.1', 'icefall-git-branch': 'master', 'icefall-git-sha1': '9a47c08-clean', 'icefall-git-date': 'Mon Aug 14 22:10:50 2023', 'icefall-path': '/private/tmp/icefall', 'k2-path': '/private/tmp/icefall_env/lib/python3.11/site-packages/k2/__init__.py', 'lhotse-path': '/private/tmp/icefall_env/lib/python3.11/site-packages/lhotse/__init__.py', 'hostname': 'fangjuns-MacBook-Pro.local', 'IP address': '127.0.0.1'}}
2023-08-16 20:45:48,092 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt
2023-08-16 20:45:48,103 INFO [decode.py:272] device: cpu
2023-08-16 20:45:48,109 INFO [checkpoint.py:112] Loading checkpoint from tdnn/exp/epoch-99.pt
2023-08-16 20:45:48,115 INFO [asr_datamodule.py:218] About to get test cuts
2023-08-16 20:45:48,115 INFO [asr_datamodule.py:253] About to get test cuts
2023-08-16 20:45:50,386 INFO [decode.py:203] batch 0/?, cuts processed until now is 4
2023-08-16 20:45:50,556 INFO [decode.py:240] The transcripts are stored in tdnn/exp/recogs-test_set.txt
2023-08-16 20:45:50,557 INFO [utils.py:564] [test_set] %WER 0.42% [1 / 240, 0 ins, 1 del, 0 sub ]
2023-08-16 20:45:50,558 INFO [decode.py:248] Wrote detailed error stats to tdnn/exp/errs-test_set.txt
2023-08-16 20:45:50,559 INFO [decode.py:315] Done!
We can see that it produces an identical WER as before.
We can also use it to decode files with the following command:
.. code-block:: bash
# ./tdnn/pretrained.py requires kaldifeat
#
# Please refer to https://csukuangfj.github.io/kaldifeat/installation/from_wheels.html
# for how to install kaldifeat
pip install kaldifeat==1.25.0.dev20230726+cpu.torch2.0.0 -f https://csukuangfj.github.io/kaldifeat/cpu.html
./tdnn/pretrained.py \
--checkpoint ./tdnn/exp/pretrained.pt \
--HLG ./data/lang_phone/HLG.pt \
--words-file ./data/lang_phone/words.txt \
download/waves_yesno/0_0_0_1_0_0_0_1.wav \
download/waves_yesno/0_0_1_0_0_0_1_0.wav
The output is given below:
.. code-block:: bash
2023-08-16 20:53:19,208 INFO [pretrained.py:136] {'feature_dim': 23, 'num_classes': 4, 'sample_rate': 8000, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'checkpoint': './tdnn/exp/pretrained.pt', 'words_file': './data/lang_phone/words.txt', 'HLG': './data/lang_phone/HLG.pt', 'sound_files': ['download/waves_yesno/0_0_0_1_0_0_0_1.wav', 'download/waves_yesno/0_0_1_0_0_0_1_0.wav']}
2023-08-16 20:53:19,208 INFO [pretrained.py:142] device: cpu
2023-08-16 20:53:19,208 INFO [pretrained.py:144] Creating model
2023-08-16 20:53:19,212 INFO [pretrained.py:156] Loading HLG from ./data/lang_phone/HLG.pt
2023-08-16 20:53:19,213 INFO [pretrained.py:160] Constructing Fbank computer
2023-08-16 20:53:19,213 INFO [pretrained.py:170] Reading sound files: ['download/waves_yesno/0_0_0_1_0_0_0_1.wav', 'download/waves_yesno/0_0_1_0_0_0_1_0.wav']
2023-08-16 20:53:19,224 INFO [pretrained.py:176] Decoding started
2023-08-16 20:53:19,304 INFO [pretrained.py:212]
download/waves_yesno/0_0_0_1_0_0_0_1.wav:
NO NO NO YES NO NO NO YES
download/waves_yesno/0_0_1_0_0_0_1_0.wav:
NO NO YES NO NO NO YES NO
2023-08-16 20:53:19,304 INFO [pretrained.py:214] Decoding Done
Export via torch.jit.script()
-----------------------------
The command for this kind of export is
.. code-block:: bash
cd /tmp/icefall
export PYTHONPATH=/tmp/icefall:$PYTHONPATH
cd egs/yesno/ASR
# assume that "--epoch 14 --avg 2" produces the lowest WER.
./tdnn/export.py --epoch 14 --avg 2 --jit true
The output logs are given below:
.. code-block:: bash
2023-08-16 20:47:44,666 INFO [export.py:76] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lr': 0.01, 'feature_dim': 23, 'weight_decay': 1e-06, 'start_epoch': 0, 'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 10, 'reset_interval': 20, 'valid_interval': 10, 'beam_size': 10, 'reduction': 'sum', 'use_double_scores': True, 'epoch': 14, 'avg': 2, 'jit': True}
2023-08-16 20:47:44,667 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt
2023-08-16 20:47:44,670 INFO [export.py:93] averaging ['tdnn/exp/epoch-13.pt', 'tdnn/exp/epoch-14.pt']
2023-08-16 20:47:44,677 INFO [export.py:100] Using torch.jit.script
2023-08-16 20:47:44,843 INFO [export.py:104] Saved to tdnn/exp/cpu_jit.pt
From the output logs we can see that the generated file is saved to ``tdnn/exp/cpu_jit.pt``.
Don't be confused by the name ``cpu_jit.pt``. The ``cpu`` part means the model is moved to
CPU before exporting. That means, when you load it with:
.. code-block:: bash
torch.jit.load()
you don't need to specify the argument `map_location <https://pytorch.org/docs/stable/generated/torch.jit.load.html#torch.jit.load>`_
and it resides on CPU by default.
To use ``tdnn/exp/cpu_jit.pt`` with `icefall`_ to decode files, we can use:
.. code-block:: bash
# ./tdnn/jit_pretrained.py requires kaldifeat
#
# Please refer to https://csukuangfj.github.io/kaldifeat/installation/from_wheels.html
# for how to install kaldifeat
pip install kaldifeat==1.25.0.dev20230726+cpu.torch2.0.0 -f https://csukuangfj.github.io/kaldifeat/cpu.html
./tdnn/jit_pretrained.py \
--nn-model ./tdnn/exp/cpu_jit.pt \
--HLG ./data/lang_phone/HLG.pt \
--words-file ./data/lang_phone/words.txt \
download/waves_yesno/0_0_0_1_0_0_0_1.wav \
download/waves_yesno/0_0_1_0_0_0_1_0.wav
The output is given below:
.. code-block:: bash
2023-08-16 20:56:00,603 INFO [jit_pretrained.py:121] {'feature_dim': 23, 'num_classes': 4, 'sample_rate': 8000, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'nn_model': './tdnn/exp/cpu_jit.pt', 'words_file': './data/lang_phone/words.txt', 'HLG': './data/lang_phone/HLG.pt', 'sound_files': ['download/waves_yesno/0_0_0_1_0_0_0_1.wav', 'download/waves_yesno/0_0_1_0_0_0_1_0.wav']}
2023-08-16 20:56:00,603 INFO [jit_pretrained.py:127] device: cpu
2023-08-16 20:56:00,603 INFO [jit_pretrained.py:129] Loading torchscript model
2023-08-16 20:56:00,640 INFO [jit_pretrained.py:134] Loading HLG from ./data/lang_phone/HLG.pt
2023-08-16 20:56:00,641 INFO [jit_pretrained.py:138] Constructing Fbank computer
2023-08-16 20:56:00,641 INFO [jit_pretrained.py:148] Reading sound files: ['download/waves_yesno/0_0_0_1_0_0_0_1.wav', 'download/waves_yesno/0_0_1_0_0_0_1_0.wav']
2023-08-16 20:56:00,642 INFO [jit_pretrained.py:154] Decoding started
2023-08-16 20:56:00,727 INFO [jit_pretrained.py:190]
download/waves_yesno/0_0_0_1_0_0_0_1.wav:
NO NO NO YES NO NO NO YES
download/waves_yesno/0_0_1_0_0_0_1_0.wav:
NO NO YES NO NO NO YES NO
2023-08-16 20:56:00,727 INFO [jit_pretrained.py:192] Decoding Done
.. hint::
We provide only code for ``torch.jit.script()``. You can try ``torch.jit.trace()``
if you want.
Export via torch.onnx.export()
------------------------------
The command for this kind of export is
.. code-block:: bash
cd /tmp/icefall
export PYTHONPATH=/tmp/icefall:$PYTHONPATH
cd egs/yesno/ASR
# tdnn/export_onnx.py requires onnx and onnxruntime
pip install onnx onnxruntime
# assume that "--epoch 14 --avg 2" produces the lowest WER.
./tdnn/export_onnx.py \
--epoch 14 \
--avg 2
The output logs are given below:
.. code-block:: bash
2023-08-16 20:59:20,888 INFO [export_onnx.py:83] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lr': 0.01, 'feature_dim': 23, 'weight_decay': 1e-06, 'start_epoch': 0, 'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 10, 'reset_interval': 20, 'valid_interval': 10, 'beam_size': 10, 'reduction': 'sum', 'use_double_scores': True, 'epoch': 14, 'avg': 2}
2023-08-16 20:59:20,888 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt
2023-08-16 20:59:20,892 INFO [export_onnx.py:100] averaging ['tdnn/exp/epoch-13.pt', 'tdnn/exp/epoch-14.pt']
================ Diagnostic Run torch.onnx.export version 2.0.0 ================
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================
2023-08-16 20:59:21,047 INFO [export_onnx.py:127] Saved to tdnn/exp/model-epoch-14-avg-2.onnx
2023-08-16 20:59:21,047 INFO [export_onnx.py:136] meta_data: {'model_type': 'tdnn', 'version': '1', 'model_author': 'k2-fsa', 'comment': 'non-streaming tdnn for the yesno recipe', 'vocab_size': 4}
2023-08-16 20:59:21,049 INFO [export_onnx.py:140] Generate int8 quantization models
2023-08-16 20:59:21,075 INFO [onnx_quantizer.py:538] Quantization parameters for tensor:"/Transpose_1_output_0" not specified
2023-08-16 20:59:21,081 INFO [export_onnx.py:151] Saved to tdnn/exp/model-epoch-14-avg-2.int8.onnx
We can see from the logs that it generates two files:
- ``tdnn/exp/model-epoch-14-avg-2.onnx`` (ONNX model with ``float32`` weights)
- ``tdnn/exp/model-epoch-14-avg-2.int8.onnx`` (ONNX model with ``int8`` weights)
To use the generated ONNX model files for decoding with `onnxruntime`_, we can use
.. code-block:: bash
# ./tdnn/onnx_pretrained.py requires kaldifeat
#
# Please refer to https://csukuangfj.github.io/kaldifeat/installation/from_wheels.html
# for how to install kaldifeat
pip install kaldifeat==1.25.0.dev20230726+cpu.torch2.0.0 -f https://csukuangfj.github.io/kaldifeat/cpu.html
./tdnn/onnx_pretrained.py \
--nn-model ./tdnn/exp/model-epoch-14-avg-2.onnx \
--HLG ./data/lang_phone/HLG.pt \
--words-file ./data/lang_phone/words.txt \
download/waves_yesno/0_0_0_1_0_0_0_1.wav \
download/waves_yesno/0_0_1_0_0_0_1_0.wav
The output is given below:
.. code-block:: bash
2023-08-16 21:03:24,260 INFO [onnx_pretrained.py:166] {'feature_dim': 23, 'sample_rate': 8000, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'nn_model': './tdnn/exp/model-epoch-14-avg-2.onnx', 'words_file': './data/lang_phone/words.txt', 'HLG': './data/lang_phone/HLG.pt', 'sound_files': ['download/waves_yesno/0_0_0_1_0_0_0_1.wav', 'download/waves_yesno/0_0_1_0_0_0_1_0.wav']}
2023-08-16 21:03:24,260 INFO [onnx_pretrained.py:171] device: cpu
2023-08-16 21:03:24,260 INFO [onnx_pretrained.py:173] Loading onnx model ./tdnn/exp/model-epoch-14-avg-2.onnx
2023-08-16 21:03:24,267 INFO [onnx_pretrained.py:176] Loading HLG from ./data/lang_phone/HLG.pt
2023-08-16 21:03:24,270 INFO [onnx_pretrained.py:180] Constructing Fbank computer
2023-08-16 21:03:24,273 INFO [onnx_pretrained.py:190] Reading sound files: ['download/waves_yesno/0_0_0_1_0_0_0_1.wav', 'download/waves_yesno/0_0_1_0_0_0_1_0.wav']
2023-08-16 21:03:24,279 INFO [onnx_pretrained.py:196] Decoding started
2023-08-16 21:03:24,318 INFO [onnx_pretrained.py:232]
download/waves_yesno/0_0_0_1_0_0_0_1.wav:
NO NO NO YES NO NO NO YES
download/waves_yesno/0_0_1_0_0_0_1_0.wav:
NO NO YES NO NO NO YES NO
2023-08-16 21:03:24,318 INFO [onnx_pretrained.py:234] Decoding Done
.. note::
To use the ``int8`` ONNX model for decoding, please use:
.. code-block:: bash
./tdnn/onnx_pretrained.py \
--nn-model ./tdnn/exp/model-epoch-14-avg-2.onnx \
--HLG ./data/lang_phone/HLG.pt \
--words-file ./data/lang_phone/words.txt \
download/waves_yesno/0_0_0_1_0_0_0_1.wav \
download/waves_yesno/0_0_1_0_0_0_1_0.wav
For the more curious
--------------------
If you are wondering how to deploy the model without ``torch``, please
continue reading. We will show how to use `sherpa-onnx`_ to run the
exported ONNX models, which depends only on `onnxruntime`_ and does not
depend on ``torch``.
In this tutorial, we will only demonstrate the usage of `sherpa-onnx`_ with the
pre-trained model of the `yesno`_ recipe. There are also other two frameworks
available:
- `sherpa`_. It works with torchscript models.
- `sherpa-ncnn`_. It works with models exported using :ref:`icefall_export_to_ncnn` with `ncnn`_
Please see `<https://k2-fsa.github.io/sherpa/>`_ for further details.

View File

@ -0,0 +1,39 @@
.. _dummies_tutorial_training:
Training
========
After :ref:`dummies_tutorial_data_preparation`, we can start training.
The command to start the training is quite simple:
.. code-block:: bash
cd /tmp/icefall
export PYTHONPATH=/tmp/icefall:$PYTHONPATH
cd egs/yesno/ASR
# We use CPU for training by setting the following environment variable
export CUDA_VISIBLE_DEVICES=""
./tdnn/train.py
That's it!
You can find the training logs below:
.. literalinclude:: ./code/train-yesno.txt
For the more curious
--------------------
.. code-block:: bash
./tdnn/train.py --help
will print the usage information about ``./tdnn/train.py``. For instance, you
can specify the number of epochs to train and the location to save the training
results.
The training text logs are saved in ``tdnn/exp/log`` while the tensorboard
logs are in ``tdnn/exp/tensorboard``.

View File

@ -20,6 +20,7 @@ speech recognition recipes using `k2 <https://github.com/k2-fsa/k2>`_.
:maxdepth: 2 :maxdepth: 2
:caption: Contents: :caption: Contents:
for-dummies/index.rst
installation/index installation/index
docker/index docker/index
faqs faqs

View File

@ -41,7 +41,7 @@ as an example.
./pruned_transducer_stateless3/export.py \ ./pruned_transducer_stateless3/export.py \
--exp-dir ./pruned_transducer_stateless3/exp \ --exp-dir ./pruned_transducer_stateless3/exp \
--bpe-model data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \ --epoch 20 \
--avg 10 --avg 10
@ -78,7 +78,7 @@ In each recipe, there is also a file ``pretrained.py``, which can use
./pruned_transducer_stateless3/pretrained.py \ ./pruned_transducer_stateless3/pretrained.py \
--checkpoint ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/pretrained-iter-1224000-avg-14.pt \ --checkpoint ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/pretrained-iter-1224000-avg-14.pt \
--bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/bpe.model \ --tokens ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
--method greedy_search \ --method greedy_search \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav \ ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav \ ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav \

View File

@ -153,11 +153,10 @@ Next, we use the following code to export our model:
./conv_emformer_transducer_stateless2/export-for-ncnn.py \ ./conv_emformer_transducer_stateless2/export-for-ncnn.py \
--exp-dir $dir/exp \ --exp-dir $dir/exp \
--bpe-model $dir/data/lang_bpe_500/bpe.model \ --tokens $dir/data/lang_bpe_500/tokens.txt \
--epoch 30 \ --epoch 30 \
--avg 1 \ --avg 1 \
--use-averaged-model 0 \ --use-averaged-model 0 \
\
--num-encoder-layers 12 \ --num-encoder-layers 12 \
--chunk-length 32 \ --chunk-length 32 \
--cnn-module-kernel 31 \ --cnn-module-kernel 31 \

View File

@ -73,7 +73,7 @@ Next, we use the following code to export our model:
./lstm_transducer_stateless2/export-for-ncnn.py \ ./lstm_transducer_stateless2/export-for-ncnn.py \
--exp-dir $dir/exp \ --exp-dir $dir/exp \
--bpe-model $dir/data/lang_bpe_500/bpe.model \ --tokens $dir/data/lang_bpe_500/tokens.txt \
--epoch 99 \ --epoch 99 \
--avg 1 \ --avg 1 \
--use-averaged-model 0 \ --use-averaged-model 0 \

View File

@ -72,12 +72,11 @@ Next, we use the following code to export our model:
dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
./pruned_transducer_stateless7_streaming/export-for-ncnn.py \ ./pruned_transducer_stateless7_streaming/export-for-ncnn.py \
--bpe-model $dir/data/lang_bpe_500/bpe.model \ --tokens $dir/data/lang_bpe_500/tokens.txt \
--exp-dir $dir/exp \ --exp-dir $dir/exp \
--use-averaged-model 0 \ --use-averaged-model 0 \
--epoch 99 \ --epoch 99 \
--avg 1 \ --avg 1 \
\
--decode-chunk-len 32 \ --decode-chunk-len 32 \
--num-left-chunks 4 \ --num-left-chunks 4 \
--num-encoder-layers "2,4,3,2,4" \ --num-encoder-layers "2,4,3,2,4" \

View File

@ -71,7 +71,7 @@ Export the model to ONNX
.. code-block:: bash .. code-block:: bash
./pruned_transducer_stateless7_streaming/export-onnx.py \ ./pruned_transducer_stateless7_streaming/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \ --tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model 0 \ --use-averaged-model 0 \
--epoch 99 \ --epoch 99 \
--avg 1 \ --avg 1 \

View File

@ -32,7 +32,7 @@ as an example in the following.
./pruned_transducer_stateless3/export.py \ ./pruned_transducer_stateless3/export.py \
--exp-dir ./pruned_transducer_stateless3/exp \ --exp-dir ./pruned_transducer_stateless3/exp \
--bpe-model data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--epoch $epoch \ --epoch $epoch \
--avg $avg \ --avg $avg \
--jit 1 --jit 1

View File

@ -33,7 +33,7 @@ as an example in the following.
./lstm_transducer_stateless2/export.py \ ./lstm_transducer_stateless2/export.py \
--exp-dir ./lstm_transducer_stateless2/exp \ --exp-dir ./lstm_transducer_stateless2/exp \
--bpe-model data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--iter $iter \ --iter $iter \
--avg $avg \ --avg $avg \
--jit-trace 1 --jit-trace 1

View File

@ -37,7 +37,7 @@ from lhotse.dataset import (
DynamicBucketingSampler, DynamicBucketingSampler,
K2SpeechRecognitionDataset, K2SpeechRecognitionDataset,
PrecomputedFeatures, PrecomputedFeatures,
SingleCutSampler, SimpleCutSampler,
SpecAugment, SpecAugment,
) )
from lhotse.dataset.input_strategies import OnTheFlyFeatures from lhotse.dataset.input_strategies import OnTheFlyFeatures
@ -292,8 +292,8 @@ class Aidatatang_200zhAsrDataModule:
buffer_size=50000, buffer_size=50000,
) )
else: else:
logging.info("Using SingleCutSampler.") logging.info("Using SimpleCutSampler.")
train_sampler = SingleCutSampler( train_sampler = SimpleCutSampler(
cuts_train, cuts_train,
max_duration=self.args.max_duration, max_duration=self.args.max_duration,
shuffle=self.args.shuffle, shuffle=self.args.shuffle,

View File

@ -322,6 +322,7 @@ def export_decoder_model_onnx(
vocab_size = decoder_model.decoder.vocab_size vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64) y = torch.zeros(10, context_size, dtype=torch.int64)
decoder_model = torch.jit.script(decoder_model)
torch.onnx.export( torch.onnx.export(
decoder_model, decoder_model,
y, y,

View File

@ -151,12 +151,14 @@ class OnnxModel:
self.encoder = ort.InferenceSession( self.encoder = ort.InferenceSession(
encoder_model_filename, encoder_model_filename,
sess_options=self.session_opts, sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
) )
def init_decoder(self, decoder_model_filename: str): def init_decoder(self, decoder_model_filename: str):
self.decoder = ort.InferenceSession( self.decoder = ort.InferenceSession(
decoder_model_filename, decoder_model_filename,
sess_options=self.session_opts, sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
) )
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
@ -170,6 +172,7 @@ class OnnxModel:
self.joiner = ort.InferenceSession( self.joiner = ort.InferenceSession(
joiner_model_filename, joiner_model_filename,
sess_options=self.session_opts, sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
) )
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map joiner_meta = self.joiner.get_modelmeta().custom_metadata_map

View File

@ -30,7 +30,7 @@ from lhotse.dataset import (
DynamicBucketingSampler, DynamicBucketingSampler,
K2SpeechRecognitionDataset, K2SpeechRecognitionDataset,
PrecomputedFeatures, PrecomputedFeatures,
SingleCutSampler, SimpleCutSampler,
SpecAugment, SpecAugment,
) )
from lhotse.dataset.input_strategies import OnTheFlyFeatures from lhotse.dataset.input_strategies import OnTheFlyFeatures
@ -278,8 +278,8 @@ class AishellAsrDataModule:
drop_last=self.args.drop_last, drop_last=self.args.drop_last,
) )
else: else:
logging.info("Using SingleCutSampler.") logging.info("Using SimpleCutSampler.")
train_sampler = SingleCutSampler( train_sampler = SimpleCutSampler(
cuts_train, cuts_train,
max_duration=self.args.max_duration, max_duration=self.args.max_duration,
shuffle=self.args.shuffle, shuffle=self.args.shuffle,

View File

@ -31,7 +31,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
DynamicBucketingSampler, DynamicBucketingSampler,
K2SpeechRecognitionDataset, K2SpeechRecognitionDataset,
PrecomputedFeatures, PrecomputedFeatures,
SingleCutSampler, SimpleCutSampler,
SpecAugment, SpecAugment,
) )
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
@ -299,8 +299,8 @@ class AiShell2AsrDataModule:
drop_last=self.args.drop_last, drop_last=self.args.drop_last,
) )
else: else:
logging.info("Using SingleCutSampler.") logging.info("Using SimpleCutSampler.")
train_sampler = SingleCutSampler( train_sampler = SimpleCutSampler(
cuts_train, cuts_train,
max_duration=self.args.max_duration, max_duration=self.args.max_duration,
shuffle=self.args.shuffle, shuffle=self.args.shuffle,

View File

@ -30,7 +30,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
DynamicBucketingSampler, DynamicBucketingSampler,
K2SpeechRecognitionDataset, K2SpeechRecognitionDataset,
PrecomputedFeatures, PrecomputedFeatures,
SingleCutSampler, SimpleCutSampler,
SpecAugment, SpecAugment,
) )
from lhotse.dataset.input_strategies import ( # noqa F401 for AudioSamples from lhotse.dataset.input_strategies import ( # noqa F401 for AudioSamples
@ -310,8 +310,8 @@ class Aishell4AsrDataModule:
drop_last=self.args.drop_last, drop_last=self.args.drop_last,
) )
else: else:
logging.info("Using SingleCutSampler.") logging.info("Using SimpleCutSampler.")
train_sampler = SingleCutSampler( train_sampler = SimpleCutSampler(
cuts_train, cuts_train,
max_duration=self.args.max_duration, max_duration=self.args.max_duration,
shuffle=self.args.shuffle, shuffle=self.args.shuffle,

View File

@ -37,7 +37,7 @@ from lhotse.dataset import (
DynamicBucketingSampler, DynamicBucketingSampler,
K2SpeechRecognitionDataset, K2SpeechRecognitionDataset,
PrecomputedFeatures, PrecomputedFeatures,
SingleCutSampler, SimpleCutSampler,
SpecAugment, SpecAugment,
) )
from lhotse.dataset.input_strategies import OnTheFlyFeatures from lhotse.dataset.input_strategies import OnTheFlyFeatures
@ -292,8 +292,8 @@ class AlimeetingAsrDataModule:
drop_last=True, drop_last=True,
) )
else: else:
logging.info("Using SingleCutSampler.") logging.info("Using SimpleCutSampler.")
train_sampler = SingleCutSampler( train_sampler = SimpleCutSampler(
cuts_train, cuts_train,
max_duration=self.args.max_duration, max_duration=self.args.max_duration,
shuffle=self.args.shuffle, shuffle=self.args.shuffle,

View File

@ -257,7 +257,7 @@ class AmiAsrDataModule:
drop_last=self.args.drop_last, drop_last=self.args.drop_last,
) )
else: else:
logging.info("Using SingleCutSampler.") logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler( train_sampler = SimpleCutSampler(
cuts_train, cuts_train,
max_duration=self.args.max_duration, max_duration=self.args.max_duration,

View File

@ -30,7 +30,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
DynamicBucketingSampler, DynamicBucketingSampler,
K2SpeechRecognitionDataset, K2SpeechRecognitionDataset,
PrecomputedFeatures, PrecomputedFeatures,
SingleCutSampler, SimpleCutSampler,
SpecAugment, SpecAugment,
) )
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
@ -311,8 +311,8 @@ class CommonVoiceAsrDataModule:
drop_last=self.args.drop_last, drop_last=self.args.drop_last,
) )
else: else:
logging.info("Using SingleCutSampler.") logging.info("Using SimpleCutSampler.")
train_sampler = SingleCutSampler( train_sampler = SimpleCutSampler(
cuts_train, cuts_train,
max_duration=self.args.max_duration, max_duration=self.args.max_duration,
shuffle=self.args.shuffle, shuffle=self.args.shuffle,

View File

@ -330,6 +330,7 @@ def export_decoder_model_onnx(
vocab_size = decoder_model.decoder.vocab_size vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64) y = torch.zeros(10, context_size, dtype=torch.int64)
decoder_model = torch.jit.script(decoder_model)
torch.onnx.export( torch.onnx.export(
decoder_model, decoder_model,
y, y,

View File

@ -152,12 +152,14 @@ class OnnxModel:
self.encoder = ort.InferenceSession( self.encoder = ort.InferenceSession(
encoder_model_filename, encoder_model_filename,
sess_options=self.session_opts, sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
) )
def init_decoder(self, decoder_model_filename: str): def init_decoder(self, decoder_model_filename: str):
self.decoder = ort.InferenceSession( self.decoder = ort.InferenceSession(
decoder_model_filename, decoder_model_filename,
sess_options=self.session_opts, sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
) )
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
@ -171,6 +173,7 @@ class OnnxModel:
self.joiner = ort.InferenceSession( self.joiner = ort.InferenceSession(
joiner_model_filename, joiner_model_filename,
sess_options=self.session_opts, sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
) )
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map joiner_meta = self.joiner.get_modelmeta().custom_metadata_map

View File

@ -31,7 +31,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
DynamicBucketingSampler, DynamicBucketingSampler,
K2SpeechRecognitionDataset, K2SpeechRecognitionDataset,
PrecomputedFeatures, PrecomputedFeatures,
SingleCutSampler, SimpleCutSampler,
SpecAugment, SpecAugment,
) )
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
@ -339,8 +339,8 @@ class CSJAsrDataModule:
drop_last=self.args.drop_last, drop_last=self.args.drop_last,
) )
else: else:
logging.info("Using SingleCutSampler.") logging.info("Using SimpleCutSampler.")
train_sampler = SingleCutSampler( train_sampler = SimpleCutSampler(
cuts_train, cuts_train,
max_duration=self.args.max_duration, max_duration=self.args.max_duration,
shuffle=self.args.shuffle, shuffle=self.args.shuffle,

View File

@ -27,7 +27,7 @@ from lhotse.dataset import (
DynamicBucketingSampler, DynamicBucketingSampler,
K2SpeechRecognitionDataset, K2SpeechRecognitionDataset,
PrecomputedFeatures, PrecomputedFeatures,
SingleCutSampler, SimpleCutSampler,
SpecAugment, SpecAugment,
) )
from lhotse.dataset.input_strategies import OnTheFlyFeatures from lhotse.dataset.input_strategies import OnTheFlyFeatures
@ -264,8 +264,8 @@ class GigaSpeechAsrDataModule:
drop_last=True, drop_last=True,
) )
else: else:
logging.info("Using SingleCutSampler.") logging.info("Using SimpleCutSampler.")
train_sampler = SingleCutSampler( train_sampler = SimpleCutSampler(
cuts_train, cuts_train,
max_duration=self.args.max_duration, max_duration=self.args.max_duration,
shuffle=self.args.shuffle, shuffle=self.args.shuffle,

View File

@ -30,7 +30,7 @@ from lhotse.dataset import (
DynamicBucketingSampler, DynamicBucketingSampler,
K2SpeechRecognitionDataset, K2SpeechRecognitionDataset,
PrecomputedFeatures, PrecomputedFeatures,
SingleCutSampler, SimpleCutSampler,
SpecAugment, SpecAugment,
) )
from lhotse.dataset.input_strategies import OnTheFlyFeatures from lhotse.dataset.input_strategies import OnTheFlyFeatures
@ -297,8 +297,8 @@ class GigaSpeechAsrDataModule:
drop_last=True, drop_last=True,
) )
else: else:
logging.info("Using SingleCutSampler.") logging.info("Using SimpleCutSampler.")
train_sampler = SingleCutSampler( train_sampler = SimpleCutSampler(
cuts_train, cuts_train,
max_duration=self.args.max_duration, max_duration=self.args.max_duration,
shuffle=self.args.shuffle, shuffle=self.args.shuffle,

View File

@ -259,7 +259,7 @@ class LibriCssAsrDataModule:
drop_last=self.args.drop_last, drop_last=self.args.drop_last,
) )
else: else:
logging.info("Using SingleCutSampler.") logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler( train_sampler = SimpleCutSampler(
cuts_train, cuts_train,
max_duration=self.args.max_duration, max_duration=self.args.max_duration,

View File

@ -1 +0,0 @@
../../../librispeech/ASR/pruned_transducer_stateless7/scaling.py

File diff suppressed because it is too large Load Diff

View File

@ -79,7 +79,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
# ln -sfv /path/to/rirs_noises $dl_dir/ # ln -sfv /path/to/rirs_noises $dl_dir/
# #
if [ ! -d $dl_dir/rirs_noises ]; then if [ ! -d $dl_dir/rirs_noises ]; then
lhotse download rirs_noises $dl_dir lhotse download rir-noise $dl_dir/rirs_noises
fi fi
fi fi
@ -89,6 +89,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
# to $dl_dir/librispeech. We perform text normalization for the transcripts. # to $dl_dir/librispeech. We perform text normalization for the transcripts.
# NOTE: Alignments are required for this recipe. # NOTE: Alignments are required for this recipe.
mkdir -p data/manifests mkdir -p data/manifests
lhotse prepare librispeech -p train-clean-100 -p train-clean-360 -p train-other-500 -p dev-clean \ lhotse prepare librispeech -p train-clean-100 -p train-clean-360 -p train-other-500 -p dev-clean \
-j 4 --alignments-dir $dl_dir/libri_alignments/LibriSpeech $dl_dir/librispeech data/manifests/ -j 4 --alignments-dir $dl_dir/libri_alignments/LibriSpeech $dl_dir/librispeech data/manifests/
fi fi
@ -112,7 +113,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
# We assume that you have downloaded the RIRS_NOISES corpus # We assume that you have downloaded the RIRS_NOISES corpus
# to $dl_dir/rirs_noises # to $dl_dir/rirs_noises
lhotse prepare rir-noise -p real_rir -p iso_noise $dl_dir/rirs_noises data/manifests lhotse prepare rir-noise -p real_rir -p iso_noise $dl_dir/rirs_noises/RIRS_NOISES data/manifests
fi fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then

View File

@ -401,6 +401,7 @@ def export_decoder_model_onnx(
vocab_size = decoder_model.decoder.vocab_size vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64) y = torch.zeros(10, context_size, dtype=torch.int64)
decoder_model = torch.jit.script(decoder_model)
torch.onnx.export( torch.onnx.export(
decoder_model, decoder_model,
y, y,

View File

@ -136,6 +136,7 @@ class OnnxModel:
self.encoder = ort.InferenceSession( self.encoder = ort.InferenceSession(
encoder_model_filename, encoder_model_filename,
sess_options=self.session_opts, sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
) )
self.init_encoder_states() self.init_encoder_states()
@ -184,6 +185,7 @@ class OnnxModel:
self.decoder = ort.InferenceSession( self.decoder = ort.InferenceSession(
decoder_model_filename, decoder_model_filename,
sess_options=self.session_opts, sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
) )
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
@ -197,6 +199,7 @@ class OnnxModel:
self.joiner = ort.InferenceSession( self.joiner = ort.InferenceSession(
joiner_model_filename, joiner_model_filename,
sess_options=self.session_opts, sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
) )
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map joiner_meta = self.joiner.get_modelmeta().custom_metadata_map

View File

@ -359,6 +359,7 @@ def export_decoder_model_onnx(
vocab_size = decoder_model.decoder.vocab_size vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64) y = torch.zeros(10, context_size, dtype=torch.int64)
decoder_model = torch.jit.script(decoder_model)
torch.onnx.export( torch.onnx.export(
decoder_model, decoder_model,
y, y,

View File

@ -356,6 +356,7 @@ def export_decoder_model_onnx(
vocab_size = decoder_model.decoder.vocab_size vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64) y = torch.zeros(10, context_size, dtype=torch.int64)
decoder_model = torch.jit.script(decoder_model)
torch.onnx.export( torch.onnx.export(
decoder_model, decoder_model,
y, y,

View File

@ -129,6 +129,7 @@ class OnnxModel:
self.encoder = ort.InferenceSession( self.encoder = ort.InferenceSession(
encoder_model_filename, encoder_model_filename,
sess_options=self.session_opts, sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
) )
self.init_encoder_states() self.init_encoder_states()
@ -166,6 +167,7 @@ class OnnxModel:
self.decoder = ort.InferenceSession( self.decoder = ort.InferenceSession(
decoder_model_filename, decoder_model_filename,
sess_options=self.session_opts, sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
) )
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
@ -179,6 +181,7 @@ class OnnxModel:
self.joiner = ort.InferenceSession( self.joiner = ort.InferenceSession(
joiner_model_filename, joiner_model_filename,
sess_options=self.session_opts, sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
) )
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map joiner_meta = self.joiner.get_modelmeta().custom_metadata_map

View File

@ -172,30 +172,35 @@ class Model:
self.encoder = ort.InferenceSession( self.encoder = ort.InferenceSession(
args.encoder_model_filename, args.encoder_model_filename,
sess_options=self.session_opts, sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
) )
def init_decoder(self, args): def init_decoder(self, args):
self.decoder = ort.InferenceSession( self.decoder = ort.InferenceSession(
args.decoder_model_filename, args.decoder_model_filename,
sess_options=self.session_opts, sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
) )
def init_joiner(self, args): def init_joiner(self, args):
self.joiner = ort.InferenceSession( self.joiner = ort.InferenceSession(
args.joiner_model_filename, args.joiner_model_filename,
sess_options=self.session_opts, sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
) )
def init_joiner_encoder_proj(self, args): def init_joiner_encoder_proj(self, args):
self.joiner_encoder_proj = ort.InferenceSession( self.joiner_encoder_proj = ort.InferenceSession(
args.joiner_encoder_proj_model_filename, args.joiner_encoder_proj_model_filename,
sess_options=self.session_opts, sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
) )
def init_joiner_decoder_proj(self, args): def init_joiner_decoder_proj(self, args):
self.joiner_decoder_proj = ort.InferenceSession( self.joiner_decoder_proj = ort.InferenceSession(
args.joiner_decoder_proj_model_filename, args.joiner_decoder_proj_model_filename,
sess_options=self.session_opts, sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
) )
def run_encoder(self, x, h0, c0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: def run_encoder(self, x, h0, c0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

View File

@ -31,7 +31,7 @@ from lhotse.dataset import (
CutMix, CutMix,
K2SpeechRecognitionDataset, K2SpeechRecognitionDataset,
PrecomputedFeatures, PrecomputedFeatures,
SingleCutSampler, SimpleCutSampler,
SpecAugment, SpecAugment,
) )
from lhotse.dataset.input_strategies import OnTheFlyFeatures from lhotse.dataset.input_strategies import OnTheFlyFeatures
@ -290,8 +290,8 @@ class LibriSpeechAsrDataModule:
drop_last=True, drop_last=True,
) )
else: else:
logging.info("Using SingleCutSampler.") logging.info("Using SimpleCutSampler.")
train_sampler = SingleCutSampler( train_sampler = SimpleCutSampler(
cuts_train, cuts_train,
max_duration=self.args.max_duration, max_duration=self.args.max_duration,
shuffle=self.args.shuffle, shuffle=self.args.shuffle,

View File

@ -307,6 +307,7 @@ def export_decoder_model_onnx(
vocab_size = decoder_model.decoder.vocab_size vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64) y = torch.zeros(10, context_size, dtype=torch.int64)
decoder_model = torch.jit.script(decoder_model)
torch.onnx.export( torch.onnx.export(
decoder_model, decoder_model,
y, y,

View File

@ -1008,7 +1008,7 @@ def modified_beam_search(
for i in range(N): for i in range(N):
B[i].add( B[i].add(
Hypothesis( Hypothesis(
ys=[blank_id] * context_size, ys=[-1] * (context_size - 1) + [blank_id],
log_prob=torch.zeros(1, dtype=torch.float32, device=device), log_prob=torch.zeros(1, dtype=torch.float32, device=device),
context_state=None if context_graph is None else context_graph.root, context_state=None if context_graph is None else context_graph.root,
timestamp=[], timestamp=[],
@ -1217,7 +1217,7 @@ def modified_beam_search_lm_rescore(
for i in range(N): for i in range(N):
B[i].add( B[i].add(
Hypothesis( Hypothesis(
ys=[blank_id] * context_size, ys=[-1] * (context_size - 1) + [blank_id],
log_prob=torch.zeros(1, dtype=torch.float32, device=device), log_prob=torch.zeros(1, dtype=torch.float32, device=device),
timestamp=[], timestamp=[],
) )
@ -1417,7 +1417,7 @@ def modified_beam_search_lm_rescore_LODR(
for i in range(N): for i in range(N):
B[i].add( B[i].add(
Hypothesis( Hypothesis(
ys=[blank_id] * context_size, ys=[-1] * (context_size - 1) + [blank_id],
log_prob=torch.zeros(1, dtype=torch.float32, device=device), log_prob=torch.zeros(1, dtype=torch.float32, device=device),
timestamp=[], timestamp=[],
) )
@ -1617,7 +1617,7 @@ def _deprecated_modified_beam_search(
B = HypothesisList() B = HypothesisList()
B.add( B.add(
Hypothesis( Hypothesis(
ys=[blank_id] * context_size, ys=[-1] * (context_size - 1) + [blank_id],
log_prob=torch.zeros(1, dtype=torch.float32, device=device), log_prob=torch.zeros(1, dtype=torch.float32, device=device),
timestamp=[], timestamp=[],
) )
@ -1753,7 +1753,11 @@ def beam_search(
t = 0 t = 0
B = HypothesisList() B = HypothesisList()
B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0, timestamp=[])) B.add(
Hypothesis(
ys=[-1] * (context_size - 1) + [blank_id], log_prob=0.0, timestamp=[]
)
)
max_sym_per_utt = 20000 max_sym_per_utt = 20000
@ -2265,7 +2269,7 @@ def modified_beam_search_ngram_rescoring(
for i in range(N): for i in range(N):
B[i].add( B[i].add(
Hypothesis( Hypothesis(
ys=[blank_id] * context_size, ys=[-1] * (context_size - 1) + [blank_id],
log_prob=torch.zeros(1, dtype=torch.float32, device=device), log_prob=torch.zeros(1, dtype=torch.float32, device=device),
state_cost=NgramLmStateCost(ngram_lm), state_cost=NgramLmStateCost(ngram_lm),
) )
@ -2385,6 +2389,7 @@ def modified_beam_search_LODR(
LODR_lm_scale: float, LODR_lm_scale: float,
LM: LmScorer, LM: LmScorer,
beam: int = 4, beam: int = 4,
context_graph: Optional[ContextGraph] = None,
) -> List[List[int]]: ) -> List[List[int]]:
"""This function implements LODR (https://arxiv.org/abs/2203.16776) with """This function implements LODR (https://arxiv.org/abs/2203.16776) with
`modified_beam_search`. It uses a bi-gram language model as the estimate `modified_beam_search`. It uses a bi-gram language model as the estimate
@ -2446,13 +2451,14 @@ def modified_beam_search_LODR(
for i in range(N): for i in range(N):
B[i].add( B[i].add(
Hypothesis( Hypothesis(
ys=[blank_id] * context_size, ys=[-1] * (context_size - 1) + [blank_id],
log_prob=torch.zeros(1, dtype=torch.float32, device=device), log_prob=torch.zeros(1, dtype=torch.float32, device=device),
state=init_states, # state of the NN LM state=init_states, # state of the NN LM
lm_score=init_score.reshape(-1), lm_score=init_score.reshape(-1),
state_cost=NgramLmStateCost( state_cost=NgramLmStateCost(
LODR_lm LODR_lm
), # state of the source domain ngram ), # state of the source domain ngram
context_state=None if context_graph is None else context_graph.root,
) )
) )
@ -2598,8 +2604,17 @@ def modified_beam_search_LODR(
hyp_log_prob = topk_log_probs[k] # get score of current hyp hyp_log_prob = topk_log_probs[k] # get score of current hyp
new_token = topk_token_indexes[k] new_token = topk_token_indexes[k]
context_score = 0
new_context_state = None if context_graph is None else hyp.context_state
if new_token not in (blank_id, unk_id): if new_token not in (blank_id, unk_id):
if context_graph is not None:
(
context_score,
new_context_state,
) = context_graph.forward_one_step(hyp.context_state, new_token)
ys.append(new_token) ys.append(new_token)
state_cost = hyp.state_cost.forward_one_step(new_token) state_cost = hyp.state_cost.forward_one_step(new_token)
@ -2615,6 +2630,7 @@ def modified_beam_search_LODR(
hyp_log_prob += ( hyp_log_prob += (
lm_score[new_token] * lm_scale lm_score[new_token] * lm_scale
+ LODR_lm_scale * current_ngram_score + LODR_lm_scale * current_ngram_score
+ context_score
) # add the lm score ) # add the lm score
lm_score = scores[count] lm_score = scores[count]
@ -2633,10 +2649,31 @@ def modified_beam_search_LODR(
state=state, state=state,
lm_score=lm_score, lm_score=lm_score,
state_cost=state_cost, state_cost=state_cost,
context_state=new_context_state,
) )
B[i].add(new_hyp) B[i].add(new_hyp)
B = B + finalized_B B = B + finalized_B
# finalize context_state, if the matched contexts do not reach final state
# we need to add the score on the corresponding backoff arc
if context_graph is not None:
finalized_B = [HypothesisList() for _ in range(len(B))]
for i, hyps in enumerate(B):
for hyp in list(hyps):
context_score, new_context_state = context_graph.finalize(
hyp.context_state
)
finalized_B[i].add(
Hypothesis(
ys=hyp.ys,
log_prob=hyp.log_prob + context_score,
timestamp=hyp.timestamp,
context_state=new_context_state,
)
)
B = finalized_B
best_hyps = [b.get_most_probable(length_norm=True) for b in B] best_hyps = [b.get_most_probable(length_norm=True) for b in B]
sorted_ans = [h.ys[context_size:] for h in best_hyps] sorted_ans = [h.ys[context_size:] for h in best_hyps]
@ -2709,7 +2746,7 @@ def modified_beam_search_lm_shallow_fusion(
for i in range(N): for i in range(N):
B[i].add( B[i].add(
Hypothesis( Hypothesis(
ys=[blank_id] * context_size, ys=[-1] * (context_size - 1) + [blank_id],
log_prob=torch.zeros(1, dtype=torch.float32, device=device), log_prob=torch.zeros(1, dtype=torch.float32, device=device),
state=init_states, state=init_states,
lm_score=init_score.reshape(-1), lm_score=init_score.reshape(-1),

View File

@ -312,6 +312,7 @@ def export_decoder_model_onnx(
vocab_size = decoder_model.decoder.vocab_size vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64) y = torch.zeros(10, context_size, dtype=torch.int64)
decoder_model = torch.jit.script(decoder_model)
torch.onnx.export( torch.onnx.export(
decoder_model, decoder_model,
y, y,

View File

@ -150,12 +150,14 @@ class OnnxModel:
self.encoder = ort.InferenceSession( self.encoder = ort.InferenceSession(
encoder_model_filename, encoder_model_filename,
sess_options=self.session_opts, sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
) )
def init_decoder(self, decoder_model_filename: str): def init_decoder(self, decoder_model_filename: str):
self.decoder = ort.InferenceSession( self.decoder = ort.InferenceSession(
decoder_model_filename, decoder_model_filename,
sess_options=self.session_opts, sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
) )
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
@ -169,6 +171,7 @@ class OnnxModel:
self.joiner = ort.InferenceSession( self.joiner = ort.InferenceSession(
joiner_model_filename, joiner_model_filename,
sess_options=self.session_opts, sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
) )
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map joiner_meta = self.joiner.get_modelmeta().custom_metadata_map

View File

@ -78,6 +78,7 @@ def test_conv2d_subsampling():
session = ort.InferenceSession( session = ort.InferenceSession(
filename, filename,
sess_options=options, sess_options=options,
providers=["CPUExecutionProvider"],
) )
input_nodes = session.get_inputs() input_nodes = session.get_inputs()
@ -133,6 +134,7 @@ def test_rel_pos():
session = ort.InferenceSession( session = ort.InferenceSession(
filename, filename,
sess_options=options, sess_options=options,
providers=["CPUExecutionProvider"],
) )
input_nodes = session.get_inputs() input_nodes = session.get_inputs()
@ -220,6 +222,7 @@ def test_conformer_encoder_layer():
session = ort.InferenceSession( session = ort.InferenceSession(
filename, filename,
sess_options=options, sess_options=options,
providers=["CPUExecutionProvider"],
) )
input_nodes = session.get_inputs() input_nodes = session.get_inputs()
@ -304,6 +307,7 @@ def test_conformer_encoder():
session = ort.InferenceSession( session = ort.InferenceSession(
filename, filename,
sess_options=options, sess_options=options,
providers=["CPUExecutionProvider"],
) )
input_nodes = session.get_inputs() input_nodes = session.get_inputs()
@ -359,6 +363,7 @@ def test_conformer():
session = ort.InferenceSession( session = ort.InferenceSession(
filename, filename,
sess_options=options, sess_options=options,
providers=["CPUExecutionProvider"],
) )
input_nodes = session.get_inputs() input_nodes = session.get_inputs()

View File

@ -404,6 +404,7 @@ def export_decoder_model_onnx(
vocab_size = decoder_model.decoder.vocab_size vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64) y = torch.zeros(10, context_size, dtype=torch.int64)
decoder_model = torch.jit.script(decoder_model)
torch.onnx.export( torch.onnx.export(
decoder_model, decoder_model,
y, y,

View File

@ -335,6 +335,7 @@ def export_decoder_model_onnx(
vocab_size = decoder_model.decoder.vocab_size vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64) y = torch.zeros(10, context_size, dtype=torch.int64)
decoder_model = torch.jit.script(decoder_model)
torch.onnx.export( torch.onnx.export(
decoder_model, decoder_model,
y, y,

View File

@ -138,6 +138,7 @@ class OnnxModel:
self.encoder = ort.InferenceSession( self.encoder = ort.InferenceSession(
encoder_model_filename, encoder_model_filename,
sess_options=self.session_opts, sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
) )
self.init_encoder_states() self.init_encoder_states()
@ -185,6 +186,7 @@ class OnnxModel:
self.decoder = ort.InferenceSession( self.decoder = ort.InferenceSession(
decoder_model_filename, decoder_model_filename,
sess_options=self.session_opts, sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
) )
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
@ -198,6 +200,7 @@ class OnnxModel:
self.joiner = ort.InferenceSession( self.joiner = ort.InferenceSession(
joiner_model_filename, joiner_model_filename,
sess_options=self.session_opts, sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
) )
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map joiner_meta = self.joiner.get_modelmeta().custom_metadata_map

View File

@ -26,7 +26,7 @@ You can generate the checkpoint with the following command:
./pruned_transducer_stateless7/export.py \ ./pruned_transducer_stateless7/export.py \
--exp-dir ./pruned_transducer_stateless7/exp \ --exp-dir ./pruned_transducer_stateless7/exp \
--bpe-model data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \ --epoch 30 \
--avg 9 --avg 9
@ -52,12 +52,12 @@ import torch
import torch.nn as nn import torch.nn as nn
from alignment import batch_force_alignment from alignment import batch_force_alignment
from asr_datamodule import LibriSpeechAsrDataModule from asr_datamodule import LibriSpeechAsrDataModule
from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import AttributeDict, convert_timestamp, parse_timestamp
from lhotse import CutSet from lhotse import CutSet
from lhotse.serialization import SequentialJsonlWriter from lhotse.serialization import SequentialJsonlWriter
from lhotse.supervision import AlignmentItem from lhotse.supervision import AlignmentItem
from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import AttributeDict, convert_timestamp, parse_timestamp
def get_parser(): def get_parser():

View File

@ -71,6 +71,10 @@ class Decoder(nn.Module):
groups=decoder_dim // 4, # group size == 4 groups=decoder_dim // 4, # group size == 4
bias=False, bias=False,
) )
else:
# To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'`
# when inference with torch.jit.script and context_size == 1
self.conv = nn.Identity()
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
""" """

View File

@ -329,6 +329,7 @@ def export_decoder_model_onnx(
vocab_size = decoder_model.decoder.vocab_size vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64) y = torch.zeros(10, context_size, dtype=torch.int64)
decoder_model = torch.jit.script(decoder_model)
torch.onnx.export( torch.onnx.export(
decoder_model, decoder_model,
y, y,

View File

@ -30,7 +30,7 @@ from lhotse.dataset import (
DynamicBucketingSampler, DynamicBucketingSampler,
K2SpeechRecognitionDataset, K2SpeechRecognitionDataset,
PrecomputedFeatures, PrecomputedFeatures,
SingleCutSampler, SimpleCutSampler,
SpecAugment, SpecAugment,
) )
from lhotse.dataset.input_strategies import OnTheFlyFeatures from lhotse.dataset.input_strategies import OnTheFlyFeatures
@ -297,8 +297,8 @@ class GigaSpeechAsrDataModule:
drop_last=True, drop_last=True,
) )
else: else:
logging.info("Using SingleCutSampler.") logging.info("Using SimpleCutSampler.")
train_sampler = SingleCutSampler( train_sampler = SimpleCutSampler(
cuts_train, cuts_train,
max_duration=self.args.max_duration, max_duration=self.args.max_duration,
shuffle=self.args.shuffle, shuffle=self.args.shuffle,

View File

@ -74,6 +74,7 @@ def test_conv2d_subsampling():
session = ort.InferenceSession( session = ort.InferenceSession(
filename, filename,
sess_options=options, sess_options=options,
providers=["CPUExecutionProvider"],
) )
input_nodes = session.get_inputs() input_nodes = session.get_inputs()
@ -128,6 +129,7 @@ def test_rel_pos():
session = ort.InferenceSession( session = ort.InferenceSession(
filename, filename,
sess_options=options, sess_options=options,
providers=["CPUExecutionProvider"],
) )
input_nodes = session.get_inputs() input_nodes = session.get_inputs()
@ -204,6 +206,7 @@ def test_zipformer_encoder_layer():
session = ort.InferenceSession( session = ort.InferenceSession(
filename, filename,
sess_options=options, sess_options=options,
providers=["CPUExecutionProvider"],
) )
input_nodes = session.get_inputs() input_nodes = session.get_inputs()
@ -284,6 +287,7 @@ def test_zipformer_encoder():
session = ort.InferenceSession( session = ort.InferenceSession(
filename, filename,
sess_options=options, sess_options=options,
providers=["CPUExecutionProvider"],
) )
input_nodes = session.get_inputs() input_nodes = session.get_inputs()
@ -338,6 +342,7 @@ def test_zipformer():
session = ort.InferenceSession( session = ort.InferenceSession(
filename, filename,
sess_options=options, sess_options=options,
providers=["CPUExecutionProvider"],
) )
input_nodes = session.get_inputs() input_nodes = session.get_inputs()

View File

@ -326,41 +326,49 @@ def main():
encoder = ort.InferenceSession( encoder = ort.InferenceSession(
args.encoder_model_filename, args.encoder_model_filename,
sess_options=session_opts, sess_options=session_opts,
providers=["CPUExecutionProvider"],
) )
decoder = ort.InferenceSession( decoder = ort.InferenceSession(
args.decoder_model_filename, args.decoder_model_filename,
sess_options=session_opts, sess_options=session_opts,
providers=["CPUExecutionProvider"],
) )
joiner = ort.InferenceSession( joiner = ort.InferenceSession(
args.joiner_model_filename, args.joiner_model_filename,
sess_options=session_opts, sess_options=session_opts,
providers=["CPUExecutionProvider"],
) )
joiner_encoder_proj = ort.InferenceSession( joiner_encoder_proj = ort.InferenceSession(
args.joiner_encoder_proj_model_filename, args.joiner_encoder_proj_model_filename,
sess_options=session_opts, sess_options=session_opts,
providers=["CPUExecutionProvider"],
) )
joiner_decoder_proj = ort.InferenceSession( joiner_decoder_proj = ort.InferenceSession(
args.joiner_decoder_proj_model_filename, args.joiner_decoder_proj_model_filename,
sess_options=session_opts, sess_options=session_opts,
providers=["CPUExecutionProvider"],
) )
lconv = ort.InferenceSession( lconv = ort.InferenceSession(
args.lconv_filename, args.lconv_filename,
sess_options=session_opts, sess_options=session_opts,
providers=["CPUExecutionProvider"],
) )
frame_reducer = ort.InferenceSession( frame_reducer = ort.InferenceSession(
args.frame_reducer_filename, args.frame_reducer_filename,
sess_options=session_opts, sess_options=session_opts,
providers=["CPUExecutionProvider"],
) )
ctc_output = ort.InferenceSession( ctc_output = ort.InferenceSession(
args.ctc_output_filename, args.ctc_output_filename,
sess_options=session_opts, sess_options=session_opts,
providers=["CPUExecutionProvider"],
) )
sp = spm.SentencePieceProcessor() sp = spm.SentencePieceProcessor()

View File

@ -413,6 +413,7 @@ def export_decoder_model_onnx(
context_size = decoder_model.decoder.context_size context_size = decoder_model.decoder.context_size
vocab_size = decoder_model.decoder.vocab_size vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64) y = torch.zeros(10, context_size, dtype=torch.int64)
decoder_model = torch.jit.script(decoder_model)
torch.onnx.export( torch.onnx.export(
decoder_model, decoder_model,
y, y,

View File

@ -401,6 +401,7 @@ def export_decoder_model_onnx(
context_size = decoder_model.decoder.context_size context_size = decoder_model.decoder.context_size
vocab_size = decoder_model.decoder.vocab_size vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64) y = torch.zeros(10, context_size, dtype=torch.int64)
decoder_model = torch.jit.script(decoder_model)
torch.onnx.export( torch.onnx.export(
decoder_model, decoder_model,
y, y,

View File

@ -130,6 +130,7 @@ class OnnxModel:
self.encoder = ort.InferenceSession( self.encoder = ort.InferenceSession(
encoder_model_filename, encoder_model_filename,
sess_options=self.session_opts, sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
) )
self.init_encoder_states() self.init_encoder_states()
@ -229,6 +230,7 @@ class OnnxModel:
self.decoder = ort.InferenceSession( self.decoder = ort.InferenceSession(
decoder_model_filename, decoder_model_filename,
sess_options=self.session_opts, sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
) )
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
@ -242,6 +244,7 @@ class OnnxModel:
self.joiner = ort.InferenceSession( self.joiner = ort.InferenceSession(
joiner_model_filename, joiner_model_filename,
sess_options=self.session_opts, sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
) )
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map joiner_meta = self.joiner.get_modelmeta().custom_metadata_map

View File

@ -865,7 +865,7 @@ class ZipformerEncoderLayer(nn.Module):
return final_dropout_rate return final_dropout_rate
else: else:
return initial_dropout_rate - ( return initial_dropout_rate - (
initial_dropout_rate * final_dropout_rate initial_dropout_rate - final_dropout_rate
) * (self.batch_count / warmup_period) ) * (self.batch_count / warmup_period)
def forward( def forward(

View File

@ -230,7 +230,7 @@ class Conformer(Transformer):
x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask
) # (T, B, F) ) # (T, B, F)
else: else:
x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F) x = self.encoder(x, pos_emb, src_key_padding_mask=src_key_padding_mask) # (T, B, F)
if self.normalize_before: if self.normalize_before:
x = self.after_norm(x) x = self.after_norm(x)

View File

@ -31,7 +31,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
DynamicBucketingSampler, DynamicBucketingSampler,
K2SpeechRecognitionDataset, K2SpeechRecognitionDataset,
PrecomputedFeatures, PrecomputedFeatures,
SingleCutSampler, SimpleCutSampler,
SpecAugment, SpecAugment,
) )
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
@ -314,8 +314,8 @@ class LibriSpeechAsrDataModule:
drop_last=self.args.drop_last, drop_last=self.args.drop_last,
) )
else: else:
logging.info("Using SingleCutSampler.") logging.info("Using SimpleCutSampler.")
train_sampler = SingleCutSampler( train_sampler = SimpleCutSampler(
cuts_train, cuts_train,
max_duration=self.args.max_duration, max_duration=self.args.max_duration,
shuffle=self.args.shuffle, shuffle=self.args.shuffle,

View File

@ -97,6 +97,7 @@ Usage:
import argparse import argparse
import logging import logging
import math import math
import os
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
@ -122,7 +123,7 @@ from beam_search import (
) )
from train import add_model_arguments, get_model, get_params from train import add_model_arguments, get_model, get_params
from icefall import LmScorer, NgramLm from icefall import ContextGraph, LmScorer, NgramLm
from icefall.checkpoint import ( from icefall.checkpoint import (
average_checkpoints, average_checkpoints,
average_checkpoints_with_averaged_model, average_checkpoints_with_averaged_model,
@ -215,6 +216,7 @@ def get_parser():
- greedy_search - greedy_search
- beam_search - beam_search
- modified_beam_search - modified_beam_search
- modified_beam_search_LODR
- fast_beam_search - fast_beam_search
- fast_beam_search_nbest - fast_beam_search_nbest
- fast_beam_search_nbest_oracle - fast_beam_search_nbest_oracle
@ -251,7 +253,7 @@ def get_parser():
type=float, type=float,
default=0.01, default=0.01,
help=""" help="""
Used only when --decoding_method is fast_beam_search_nbest_LG. Used only when --decoding-method is fast_beam_search_nbest_LG.
It specifies the scale for n-gram LM scores. It specifies the scale for n-gram LM scores.
""", """,
) )
@ -285,7 +287,7 @@ def get_parser():
type=int, type=int,
default=1, default=1,
help="""Maximum number of symbols per frame. help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""", Used only when --decoding-method is greedy_search""",
) )
parser.add_argument( parser.add_argument(
@ -347,6 +349,27 @@ def get_parser():
help="ID of the backoff symbol in the ngram LM", help="ID of the backoff symbol in the ngram LM",
) )
parser.add_argument(
"--context-score",
type=float,
default=2,
help="""
The bonus score of each token for the context biasing words/phrases.
Used only when --decoding-method is modified_beam_search and
modified_beam_search_LODR.
""",
)
parser.add_argument(
"--context-file",
type=str,
default="",
help="""
The path of the context biasing lists, one word/phrase each line
Used only when --decoding-method is modified_beam_search and
modified_beam_search_LODR.
""",
)
add_model_arguments(parser) add_model_arguments(parser)
return parser return parser
@ -359,6 +382,7 @@ def decode_one_batch(
batch: dict, batch: dict,
word_table: Optional[k2.SymbolTable] = None, word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None, decoding_graph: Optional[k2.Fsa] = None,
context_graph: Optional[ContextGraph] = None,
LM: Optional[LmScorer] = None, LM: Optional[LmScorer] = None,
ngram_lm=None, ngram_lm=None,
ngram_lm_scale: float = 0.0, ngram_lm_scale: float = 0.0,
@ -388,7 +412,7 @@ def decode_one_batch(
The word symbol table. The word symbol table.
decoding_graph: decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest, only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
LM: LM:
A neural network language model. A neural network language model.
@ -493,6 +517,7 @@ def decode_one_batch(
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens, encoder_out_lens=encoder_out_lens,
beam=params.beam_size, beam=params.beam_size,
context_graph=context_graph,
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split()) hyps.append(hyp.split())
@ -515,6 +540,7 @@ def decode_one_batch(
LODR_lm=ngram_lm, LODR_lm=ngram_lm,
LODR_lm_scale=ngram_lm_scale, LODR_lm_scale=ngram_lm_scale,
LM=LM, LM=LM,
context_graph=context_graph,
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split()) hyps.append(hyp.split())
@ -578,16 +604,22 @@ def decode_one_batch(
key += f"_ngram_lm_scale_{params.ngram_lm_scale}" key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {key: hyps} return {key: hyps}
elif params.decoding_method in ( elif "modified_beam_search" in params.decoding_method:
"modified_beam_search_lm_rescore", prefix = f"beam_size_{params.beam_size}"
"modified_beam_search_lm_rescore_LODR", if params.decoding_method in (
): "modified_beam_search_lm_rescore",
ans = dict() "modified_beam_search_lm_rescore_LODR",
assert ans_dict is not None ):
for key, hyps in ans_dict.items(): ans = dict()
hyps = [sp.decode(hyp).split() for hyp in hyps] assert ans_dict is not None
ans[f"beam_size_{params.beam_size}_{key}"] = hyps for key, hyps in ans_dict.items():
return ans hyps = [sp.decode(hyp).split() for hyp in hyps]
ans[f"{prefix}_{key}"] = hyps
return ans
else:
if params.has_contexts:
prefix += f"-context-score-{params.context_score}"
return {prefix: hyps}
else: else:
return {f"beam_size_{params.beam_size}": hyps} return {f"beam_size_{params.beam_size}": hyps}
@ -599,6 +631,7 @@ def decode_dataset(
sp: spm.SentencePieceProcessor, sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None, word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None, decoding_graph: Optional[k2.Fsa] = None,
context_graph: Optional[ContextGraph] = None,
LM: Optional[LmScorer] = None, LM: Optional[LmScorer] = None,
ngram_lm=None, ngram_lm=None,
ngram_lm_scale: float = 0.0, ngram_lm_scale: float = 0.0,
@ -618,7 +651,7 @@ def decode_dataset(
The word symbol table. The word symbol table.
decoding_graph: decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest, only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns: Returns:
Return a dict, whose key may be "greedy_search" if greedy search Return a dict, whose key may be "greedy_search" if greedy search
@ -649,6 +682,7 @@ def decode_dataset(
model=model, model=model,
sp=sp, sp=sp,
decoding_graph=decoding_graph, decoding_graph=decoding_graph,
context_graph=context_graph,
word_table=word_table, word_table=word_table,
batch=batch, batch=batch,
LM=LM, LM=LM,
@ -744,6 +778,11 @@ def main():
) )
params.res_dir = params.exp_dir / params.decoding_method params.res_dir = params.exp_dir / params.decoding_method
if os.path.exists(params.context_file):
params.has_contexts = True
else:
params.has_contexts = False
if params.iter > 0: if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}" params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else: else:
@ -770,6 +809,12 @@ def main():
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method: elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
if params.decoding_method in (
"modified_beam_search",
"modified_beam_search_LODR",
):
if params.has_contexts:
params.suffix += f"-context-score-{params.context_score}"
else: else:
params.suffix += f"-context-{params.context_size}" params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -952,6 +997,18 @@ def main():
decoding_graph = None decoding_graph = None
word_table = None word_table = None
if "modified_beam_search" in params.decoding_method:
if os.path.exists(params.context_file):
contexts = []
for line in open(params.context_file).readlines():
contexts.append(line.strip())
context_graph = ContextGraph(params.context_score)
context_graph.build(sp.encode(contexts))
else:
context_graph = None
else:
context_graph = None
num_param = sum([p.numel() for p in model.parameters()]) num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}") logging.info(f"Number of model parameters: {num_param}")
@ -976,6 +1033,7 @@ def main():
sp=sp, sp=sp,
word_table=word_table, word_table=word_table,
decoding_graph=decoding_graph, decoding_graph=decoding_graph,
context_graph=context_graph,
LM=LM, LM=LM,
ngram_lm=ngram_lm, ngram_lm=ngram_lm,
ngram_lm_scale=ngram_lm_scale, ngram_lm_scale=ngram_lm_scale,

View File

@ -506,6 +506,7 @@ def export_decoder_model_onnx(
vocab_size = decoder_model.decoder.vocab_size vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64) y = torch.zeros(10, context_size, dtype=torch.int64)
decoder_model = torch.jit.script(decoder_model)
torch.onnx.export( torch.onnx.export(
decoder_model, decoder_model,
y, y,

View File

@ -353,6 +353,7 @@ def export_decoder_model_onnx(
vocab_size = decoder_model.decoder.vocab_size vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64) y = torch.zeros(10, context_size, dtype=torch.int64)
decoder_model = torch.jit.script(decoder_model)
torch.onnx.export( torch.onnx.export(
decoder_model, decoder_model,
y, y,

View File

@ -146,6 +146,7 @@ class OnnxModel:
self.encoder = ort.InferenceSession( self.encoder = ort.InferenceSession(
encoder_model_filename, encoder_model_filename,
sess_options=self.session_opts, sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
) )
self.init_encoder_states() self.init_encoder_states()
@ -236,6 +237,7 @@ class OnnxModel:
self.decoder = ort.InferenceSession( self.decoder = ort.InferenceSession(
decoder_model_filename, decoder_model_filename,
sess_options=self.session_opts, sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
) )
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
@ -249,6 +251,7 @@ class OnnxModel:
self.joiner = ort.InferenceSession( self.joiner = ort.InferenceSession(
joiner_model_filename, joiner_model_filename,
sess_options=self.session_opts, sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
) )
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map joiner_meta = self.joiner.get_modelmeta().custom_metadata_map

View File

@ -151,12 +151,14 @@ class OnnxModel:
self.encoder = ort.InferenceSession( self.encoder = ort.InferenceSession(
encoder_model_filename, encoder_model_filename,
sess_options=self.session_opts, sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
) )
def init_decoder(self, decoder_model_filename: str): def init_decoder(self, decoder_model_filename: str):
self.decoder = ort.InferenceSession( self.decoder = ort.InferenceSession(
decoder_model_filename, decoder_model_filename,
sess_options=self.session_opts, sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
) )
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
@ -170,6 +172,7 @@ class OnnxModel:
self.joiner = ort.InferenceSession( self.joiner = ort.InferenceSession(
joiner_model_filename, joiner_model_filename,
sess_options=self.session_opts, sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
) )
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map joiner_meta = self.joiner.get_modelmeta().custom_metadata_map

File diff suppressed because it is too large Load Diff

View File

@ -17,7 +17,7 @@ from lhotse.dataset import (
DynamicBucketingSampler, DynamicBucketingSampler,
K2SpeechRecognitionDataset, K2SpeechRecognitionDataset,
PrecomputedFeatures, PrecomputedFeatures,
SingleCutSampler, SimpleCutSampler,
SpecAugment, SpecAugment,
) )
from lhotse.dataset.input_strategies import OnTheFlyFeatures from lhotse.dataset.input_strategies import OnTheFlyFeatures
@ -270,8 +270,8 @@ class MGB2AsrDataModule:
drop_last=self.args.drop_last, drop_last=self.args.drop_last,
) )
else: else:
logging.info("Using SingleCutSampler.") logging.info("Using SimpleCutSampler.")
train_sampler = SingleCutSampler( train_sampler = SimpleCutSampler(
cuts_train, cuts_train,
max_duration=self.args.max_duration, max_duration=self.args.max_duration,
shuffle=self.args.shuffle, shuffle=self.args.shuffle,

View File

@ -0,0 +1,39 @@
# Introduction
This recipe includes scripts for training Zipformer model using multiple Chinese datasets.
# Included Training Sets
1. THCHS-30
2. AiShell-{1,2,4}
3. ST-CMDS
4. Primewords
5. MagicData
6. Aidatatang_200zh
7. AliMeeting
8. WeNetSpeech
9. KeSpeech-ASR
|Datset| Number of hours| URL|
|---|---:|---|
|**TOTAL**|14,106|---|
|THCHS-30|35|https://www.openslr.org/18/|
|AiShell-1|170|https://www.openslr.org/33/|
|AiShell-2|1,000|http://www.aishelltech.com/aishell_2|
|AiShell-4|120|https://www.openslr.org/111/|
|ST-CMDS|110|https://www.openslr.org/38/|
|Primewords|99|https://www.openslr.org/47/|
|aidatatang_200zh|200|https://www.openslr.org/62/|
|MagicData|755|https://www.openslr.org/68/|
|AliMeeting|100|https://openslr.org/119/|
|WeNetSpeech|10,000|https://github.com/wenet-e2e/WenetSpeech|
|KeSpeech|1,542|https://github.com/KeSpeech/KeSpeech|
# Included Test Sets
1. Aishell-{1,2,4}
2. Aidatatang_200zh
3. AliMeeting
4. MagicData
5. KeSpeech-ASR
6. WeNetSpeech

View File

@ -0,0 +1,38 @@
## Results
### Multi Chinese datasets char-based training results (Non-streaming) on zipformer model
This is the [pull request #1238](https://github.com/k2-fsa/icefall/pull/1238) in icefall.
#### Non-streaming
Best results (num of params : ~69M):
The training command:
```
./zipformer/train.py \
--world-size 4 \
--num-epochs 20 \
--use-fp16 1 \
--max-duration 600 \
--num-workers 8
```
The decoding command:
```
./zipformer/decode.py \
--epoch 20 \
--avg 1
```
Character Error Rates (CERs) listed below are produced by the checkpoint of the 20th epoch using greedy search and BPE model ( # tokens is 2000, byte fallback enabled).
| Datasets | aidatatang _200zh | aidatatang _200zh | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech |
|--------------------------------|------------------------------|-------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|--------------------|-------------------------|---------------------|
| Zipformer CER (%) | dev | test | eval | test | dev | test | dev | test | test | dev | test | dev phase1 | dev phase2 | test | dev | test meeting | test net |
| | 3.2 | 3.67 | 23.15 | 24.78 | 2.91 | 3.04 | 3.59 | 4.03 | 15.68 | 3.68 | 3.12 | 6.69 | 3.19 | 8.01 | 9.32 | 7.05 | 8.78 |
The pre-trained model is available here : https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2

View File

@ -0,0 +1,37 @@
#!/usr/bin/env python3
"""
This script takes `bpe.model` as input and generates a file `tokens.txt`
from it.
Usage:
./bpe_model_to_tokens.py /path/to/input/bpe.model > tokens.txt
"""
import argparse
import sentencepiece as spm
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"bpe_model",
type=str,
help="Path to the input bpe.model",
)
return parser.parse_args()
def main():
args = get_args()
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model)
for i in range(sp.vocab_size()):
print(sp.id_to_piece(i), i)
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/compile_lg.py

View File

@ -0,0 +1,93 @@
#!/usr/bin/env python3
# Copyright 2021 Johns Hopkins University (Piotr Żelasko)
# Copyright 2021 Xiaomi Corp. (Fangjun Kuang)
# Copyright 2023 Xiaomi Corp. (Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pathlib import Path
import torch
from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_kespeech_dev_test():
in_out_dir = Path("data/fbank/kespeech")
# number of workers in dataloader
num_workers = 42
# number of seconds in a batch
batch_duration = 600
subsets = (
"dev_phase1",
"dev_phase2",
"test",
)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
logging.info(f"device: {device}")
for partition in subsets:
cuts_path = in_out_dir / f"kespeech-asr_cuts_{partition}.jsonl.gz"
if cuts_path.is_file():
logging.info(f"{cuts_path} exists - skipping")
continue
raw_cuts_path = in_out_dir / f"kespeech-asr_cuts_{partition}_raw.jsonl.gz"
logging.info(f"Loading {raw_cuts_path}")
cut_set = CutSet.from_file(raw_cuts_path)
logging.info("Splitting cuts into smaller chunks")
cut_set = cut_set.trim_to_supervisions(
keep_overlapping=False, min_duration=None
)
logging.info("Computing features")
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
storage_path=f"{in_out_dir}/feats_{partition}",
num_workers=num_workers,
batch_duration=batch_duration,
storage_type=LilcomChunkyWriter,
overwrite=True,
)
logging.info(f"Saving to {cuts_path}")
cut_set.to_file(cuts_path)
def main():
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_kespeech_dev_test()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,180 @@
#!/usr/bin/env python3
# Copyright 2021 Johns Hopkins University (Piotr Żelasko)
# Copyright 2021 Xiaomi Corp. (Fangjun Kuang)
# Copyright 2023 Xiaomi Corp. (Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from datetime import datetime
from pathlib import Path
import torch
from lhotse import (
CutSet,
KaldifeatFbank,
KaldifeatFbankConfig,
LilcomChunkyWriter,
set_audio_duration_mismatch_tolerance,
set_caching_enabled,
)
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--training-subset",
type=str,
default="train_phase1",
choices=["train_phase1", "train_phase2"],
help="The training subset for computing fbank feature.",
)
parser.add_argument(
"--num-workers",
type=int,
default=20,
help="Number of dataloading workers used for reading the audio.",
)
parser.add_argument(
"--batch-duration",
type=float,
default=600.0,
help="The maximum number of audio seconds in a batch."
"Determines batch size dynamically.",
)
parser.add_argument(
"--num-splits",
type=int,
required=True,
help="The number of splits of the given subset",
)
parser.add_argument(
"--start",
type=int,
default=0,
help="Process pieces starting from this number (inclusive).",
)
parser.add_argument(
"--stop",
type=int,
default=-1,
help="Stop processing pieces until this number (exclusive).",
)
return parser
def compute_fbank_kespeech_splits(args):
subset = args.training_subset
subset = str(subset)
num_splits = args.num_splits
output_dir = f"data/fbank/kespeech/{subset}_split_{num_splits}"
output_dir = Path(output_dir)
assert output_dir.exists(), f"{output_dir} does not exist!"
num_digits = len(str(num_splits))
start = args.start
stop = args.stop
if stop < start:
stop = num_splits
stop = min(stop, num_splits)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
logging.info(f"device: {device}")
set_audio_duration_mismatch_tolerance(0.01) # 10ms tolerance
set_caching_enabled(False)
for i in range(start, stop):
idx = f"{i + 1}".zfill(num_digits)
logging.info(f"Processing {idx}/{num_splits}")
cuts_path = output_dir / f"kespeech-asr_cuts_{subset}.{idx}.jsonl.gz"
if cuts_path.is_file():
logging.info(f"{cuts_path} exists - skipping")
continue
raw_cuts_path = output_dir / f"kespeech-asr_cuts_{subset}_raw.{idx}.jsonl.gz"
logging.info(f"Loading {raw_cuts_path}")
cut_set = CutSet.from_file(raw_cuts_path)
logging.info("Splitting cuts into smaller chunks.")
cut_set = cut_set.trim_to_supervisions(
keep_overlapping=False, min_duration=None
)
logging.info("Computing features")
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
storage_path=f"{output_dir}/feats_{subset}_{idx}",
num_workers=args.num_workers,
batch_duration=args.batch_duration,
storage_type=LilcomChunkyWriter,
overwrite=True,
)
logging.info(f"Saving to {cuts_path}")
cut_set.to_file(cuts_path)
def main():
now = datetime.now()
date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
log_filename = "log-compute_fbank_kespeech_splits"
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
log_filename = f"{log_filename}-{date_time}"
logging.basicConfig(
filename=log_filename,
format=formatter,
level=logging.INFO,
filemode="w",
)
console = logging.StreamHandler()
console.setLevel(logging.INFO)
console.setFormatter(logging.Formatter(formatter))
logging.getLogger("").addHandler(console)
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
compute_fbank_kespeech_splits(args)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,122 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang
# Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the MagicData dataset.
It looks for manifests in the directory data/manifests/magicdata.
The generated fbank features are saved in data/fbank.
"""
import argparse
import logging
import os
from pathlib import Path
import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_magicdata(num_mel_bins: int = 80, speed_perturb: bool = False):
src_dir = Path("data/manifests/magicdata")
output_dir = Path("data/fbank")
num_jobs = min(30, os.cpu_count())
dataset_parts = ("train", "test", "dev")
prefix = "magicdata"
suffix = "jsonl.gz"
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
output_dir=src_dir,
prefix=prefix,
suffix=suffix,
)
assert manifests is not None
assert len(manifests) == len(dataset_parts), (
len(manifests),
len(dataset_parts),
list(manifests.keys()),
dataset_parts,
)
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once.
for partition, m in manifests.items():
if (output_dir / f"{prefix}_cuts_{partition}.{suffix}").is_file():
logging.info(f"{partition} already exists - skipping.")
continue
logging.info(f"Processing {partition}")
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
)
if "train" in partition and speed_perturb:
cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
# when an executor is specified, make more partitions
num_jobs=num_jobs if ex is None else 80,
executor=ex,
storage_type=LilcomChunkyWriter,
)
cut_set.to_file(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--num-mel-bins",
type=int,
default=80,
help="""The number of mel bins for Fbank""",
)
parser.add_argument(
"--speed-perturb",
type=bool,
default=False,
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
)
return parser.parse_args()
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
compute_fbank_magicdata(
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb
)

View File

@ -0,0 +1,122 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang
# Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the Primewords dataset.
It looks for manifests in the directory data/manifests/primewords.
The generated fbank features are saved in data/fbank.
"""
import argparse
import logging
import os
from pathlib import Path
import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_primewords(num_mel_bins: int = 80, speed_perturb: bool = False):
src_dir = Path("data/manifests/primewords")
output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count())
dataset_parts = ("train",)
prefix = "primewords"
suffix = "jsonl.gz"
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
output_dir=src_dir,
prefix=prefix,
suffix=suffix,
)
assert manifests is not None
assert len(manifests) == len(dataset_parts), (
len(manifests),
len(dataset_parts),
list(manifests.keys()),
dataset_parts,
)
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once.
for partition, m in manifests.items():
if (output_dir / f"{prefix}_cuts_{partition}.{suffix}").is_file():
logging.info(f"{partition} already exists - skipping.")
continue
logging.info(f"Processing {partition}")
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
)
if "train" in partition and speed_perturb:
cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
# when an executor is specified, make more partitions
num_jobs=num_jobs if ex is None else 80,
executor=ex,
storage_type=LilcomChunkyWriter,
)
cut_set.to_file(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--num-mel-bins",
type=int,
default=80,
help="""The number of mel bins for Fbank""",
)
parser.add_argument(
"--speed-perturb",
type=bool,
default=False,
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
)
return parser.parse_args()
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
compute_fbank_primewords(
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb
)

View File

@ -0,0 +1,121 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang
# Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the ST-CMDS dataset.
It looks for manifests in the directory data/manifests/stcmds.
The generated fbank features are saved in data/fbank.
"""
import argparse
import logging
import os
from pathlib import Path
import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_stcmds(num_mel_bins: int = 80, speed_perturb: bool = False):
src_dir = Path("data/manifests/stcmds")
output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count())
dataset_parts = ("train",)
prefix = "stcmds"
suffix = "jsonl.gz"
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
output_dir=src_dir,
prefix=prefix,
suffix=suffix,
)
assert manifests is not None
assert len(manifests) == len(dataset_parts), (
len(manifests),
len(dataset_parts),
list(manifests.keys()),
dataset_parts,
)
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once.
for partition, m in manifests.items():
if (output_dir / f"{prefix}_cuts_{partition}.{suffix}").is_file():
logging.info(f"{partition} already exists - skipping.")
continue
logging.info(f"Processing {partition}")
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
)
if "train" in partition and speed_perturb:
cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
# when an executor is specified, make more partitions
num_jobs=num_jobs if ex is None else 80,
executor=ex,
storage_type=LilcomChunkyWriter,
)
cut_set.to_file(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--num-mel-bins",
type=int,
default=80,
help="""The number of mel bins for Fbank""",
)
parser.add_argument(
"--speed-perturb",
type=bool,
default=False,
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
)
return parser.parse_args()
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
compute_fbank_stcmds(
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb
)

View File

@ -0,0 +1,127 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang
# Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the THCHS-30 dataset.
It looks for manifests in the directory data/manifests/thchs30.
The generated fbank features are saved in data/fbank.
"""
import argparse
import logging
import os
from pathlib import Path
import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_thchs30(num_mel_bins: int = 80, speed_perturb: bool = False):
src_dir = Path("data/manifests/thchs30")
output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count())
dataset_parts = (
"train",
"dev",
"test",
)
prefix = "thchs_30"
suffix = "jsonl.gz"
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
output_dir=src_dir,
prefix=prefix,
suffix=suffix,
)
assert manifests is not None
assert len(manifests) == len(dataset_parts), (
len(manifests),
len(dataset_parts),
list(manifests.keys()),
dataset_parts,
)
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once.
for partition, m in manifests.items():
if (output_dir / f"{prefix}_cuts_{partition}.{suffix}").is_file():
logging.info(f"{partition} already exists - skipping.")
continue
logging.info(f"Processing {partition}")
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
)
if "train" in partition:
cut_set = (
(cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1))
if speed_perturb
else cut_set
)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
# when an executor is specified, make more partitions
num_jobs=num_jobs if ex is None else 80,
executor=ex,
storage_type=LilcomChunkyWriter,
)
cut_set.to_file(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--num-mel-bins",
type=int,
default=80,
help="""The number of mel bins for Fbank""",
)
parser.add_argument(
"--speed-perturb",
type=bool,
default=False,
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
)
return parser.parse_args()
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
compute_fbank_thchs30(
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb
)

View File

@ -0,0 +1 @@
../../../wenetspeech/ASR/local/prepare_char.py

Some files were not shown because too many files have changed in this diff Show More