mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 22:54:18 +00:00
Merge branch 'k2-fsa:master' into dev_zipformer_cn
This commit is contained in:
commit
023f6e05d4
@ -29,6 +29,9 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" ==
|
||||
ls -lh data/fbank
|
||||
ls -lh pruned_transducer_stateless2/exp
|
||||
|
||||
ln -s data/fbank/cuts_DEV.jsonl.gz data/fbank/gigaspeech_cuts_DEV.jsonl.gz
|
||||
ln -s data/fbank/cuts_TEST.jsonl.gz data/fbank/gigaspeech_cuts_TEST.jsonl.gz
|
||||
|
||||
log "Decoding dev and test"
|
||||
|
||||
# use a small value for decoding with CPU
|
||||
|
51
.github/scripts/run-multi-zh_hans-zipformer.sh
vendored
Executable file
51
.github/scripts/run-multi-zh_hans-zipformer.sh
vendored
Executable file
@ -0,0 +1,51 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -e
|
||||
|
||||
log() {
|
||||
# This function is from espnet
|
||||
local fname=${BASH_SOURCE[1]##*/}
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
cd egs/multi_zh-hans/ASR
|
||||
|
||||
repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2/
|
||||
|
||||
log "Downloading pre-trained model from $repo_url"
|
||||
git lfs install
|
||||
git clone $repo_url
|
||||
repo=$(basename $repo_url)
|
||||
|
||||
|
||||
log "Display test files"
|
||||
tree $repo/
|
||||
ls -lh $repo/test_wavs/*.wav
|
||||
|
||||
pushd $repo/exp
|
||||
ln -s epoch-20.pt epoch-99.pt
|
||||
popd
|
||||
|
||||
ls -lh $repo/exp/*.pt
|
||||
|
||||
|
||||
./zipformer/pretrained.py \
|
||||
--checkpoint $repo/exp/epoch-99.pt \
|
||||
--tokens $repo/data/lang_bpe_2000/tokens.txt \
|
||||
--method greedy_search \
|
||||
$repo/test_wavs/DEV_T0000000000.wav \
|
||||
$repo/test_wavs/DEV_T0000000001.wav \
|
||||
$repo/test_wavs/DEV_T0000000002.wav
|
||||
|
||||
for method in modified_beam_search fast_beam_search; do
|
||||
log "$method"
|
||||
|
||||
./zipformer/pretrained.py \
|
||||
--method $method \
|
||||
--beam-size 4 \
|
||||
--checkpoint $repo/exp/epoch-99.pt \
|
||||
--tokens $repo/data/lang_bpe_2000/tokens.txt \
|
||||
$repo/test_wavs/DEV_T0000000000.wav \
|
||||
$repo/test_wavs/DEV_T0000000001.wav \
|
||||
$repo/test_wavs/DEV_T0000000002.wav
|
||||
done
|
2
.github/workflows/run-aishell-2022-06-20.yml
vendored
2
.github/workflows/run-aishell-2022-06-20.yml
vendored
@ -45,7 +45,7 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest]
|
||||
python-version: [3.7, 3.8, 3.9]
|
||||
python-version: [3.8]
|
||||
|
||||
fail-fast: false
|
||||
|
||||
|
@ -44,7 +44,7 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest]
|
||||
python-version: [3.7, 3.8, 3.9]
|
||||
python-version: [3.8]
|
||||
|
||||
fail-fast: false
|
||||
|
||||
|
@ -44,7 +44,7 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest]
|
||||
python-version: [3.7, 3.8, 3.9]
|
||||
python-version: [3.8]
|
||||
|
||||
fail-fast: false
|
||||
|
||||
|
@ -44,7 +44,7 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest]
|
||||
python-version: [3.7, 3.8, 3.9]
|
||||
python-version: [3.8]
|
||||
|
||||
fail-fast: false
|
||||
|
||||
|
@ -44,7 +44,7 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest]
|
||||
python-version: [3.7, 3.8, 3.9]
|
||||
python-version: [3.8]
|
||||
|
||||
fail-fast: false
|
||||
|
||||
|
@ -44,7 +44,7 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest]
|
||||
python-version: [3.7, 3.8, 3.9]
|
||||
python-version: [3.8]
|
||||
|
||||
fail-fast: false
|
||||
|
||||
|
@ -44,7 +44,7 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest]
|
||||
python-version: [3.7, 3.8, 3.9]
|
||||
python-version: [3.8]
|
||||
|
||||
fail-fast: false
|
||||
|
||||
|
@ -44,7 +44,7 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest]
|
||||
python-version: [3.7, 3.8, 3.9]
|
||||
python-version: [3.8]
|
||||
|
||||
fail-fast: false
|
||||
|
||||
|
84
.github/workflows/run-multi-zh_hans-zipformer.yml
vendored
Normal file
84
.github/workflows/run-multi-zh_hans-zipformer.yml
vendored
Normal file
@ -0,0 +1,84 @@
|
||||
# Copyright 2023 Xiaomi Corp. (author: Zengrui Jin)
|
||||
|
||||
# See ../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
name: run-multi-zh_hans-zipformer
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
pull_request:
|
||||
types: [labeled]
|
||||
|
||||
concurrency:
|
||||
group: run_multi-zh_hans_zipformer-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
run_multi-zh_hans_zipformer:
|
||||
if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'multi-zh_hans'
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest]
|
||||
python-version: [3.8]
|
||||
|
||||
fail-fast: false
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache: 'pip'
|
||||
cache-dependency-path: '**/requirements-ci.txt'
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: |
|
||||
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
|
||||
pip uninstall -y protobuf
|
||||
pip install --no-binary protobuf protobuf==3.20.*
|
||||
|
||||
- name: Cache kaldifeat
|
||||
id: my-cache
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
path: |
|
||||
~/tmp/kaldifeat
|
||||
key: cache-tmp-${{ matrix.python-version }}-2023-05-22
|
||||
|
||||
- name: Install kaldifeat
|
||||
if: steps.my-cache.outputs.cache-hit != 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
.github/scripts/install-kaldifeat.sh
|
||||
|
||||
- name: Inference with pre-trained model
|
||||
shell: bash
|
||||
env:
|
||||
GITHUB_EVENT_NAME: ${{ github.event_name }}
|
||||
GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
|
||||
run: |
|
||||
sudo apt-get -qq install git-lfs tree
|
||||
export PYTHONPATH=$PWD:$PYTHONPATH
|
||||
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
|
||||
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
|
||||
|
||||
.github/scripts/run-multi-zh_hans-zipformer.sh
|
@ -34,7 +34,7 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest]
|
||||
python-version: [3.7, 3.8, 3.9]
|
||||
python-version: [3.8]
|
||||
|
||||
fail-fast: false
|
||||
|
||||
|
@ -43,7 +43,7 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest]
|
||||
python-version: [3.7, 3.8, 3.9]
|
||||
python-version: [3.8]
|
||||
|
||||
fail-fast: false
|
||||
|
||||
|
@ -43,7 +43,7 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest]
|
||||
python-version: [3.7, 3.8, 3.9]
|
||||
python-version: [3.8]
|
||||
|
||||
fail-fast: false
|
||||
|
||||
|
@ -34,7 +34,7 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest]
|
||||
python-version: [3.7, 3.8, 3.9]
|
||||
python-version: [3.8]
|
||||
|
||||
fail-fast: false
|
||||
|
||||
|
@ -34,7 +34,7 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest]
|
||||
python-version: [3.7, 3.8, 3.9]
|
||||
python-version: [3.8]
|
||||
|
||||
fail-fast: false
|
||||
|
||||
|
@ -43,7 +43,7 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest]
|
||||
python-version: [3.7, 3.8, 3.9]
|
||||
python-version: [3.8]
|
||||
|
||||
fail-fast: false
|
||||
|
||||
|
@ -34,7 +34,7 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest]
|
||||
python-version: [3.7, 3.8, 3.9]
|
||||
python-version: [3.8]
|
||||
|
||||
fail-fast: false
|
||||
|
||||
|
@ -338,7 +338,7 @@ We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder
|
||||
|
||||
#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss
|
||||
|
||||
The best results for Chinese CER(%) and English WER(%) respectivly (zh: Chinese, en: English):
|
||||
The best results for Chinese CER(%) and English WER(%) respectively (zh: Chinese, en: English):
|
||||
|decoding-method | dev | dev_zh | dev_en | test | test_zh | test_en |
|
||||
|--|--|--|--|--|--|--|
|
||||
|greedy_search| 7.30 | 6.48 | 19.19 |7.39| 6.66 | 19.13|
|
||||
|
@ -95,4 +95,7 @@ rst_epilog = """
|
||||
.. _k2: https://github.com/k2-fsa/k2
|
||||
.. _lhotse: https://github.com/lhotse-speech/lhotse
|
||||
.. _yesno: https://www.openslr.org/1/
|
||||
.. _Next-gen Kaldi: https://github.com/k2-fsa
|
||||
.. _Kaldi: https://github.com/kaldi-asr/kaldi
|
||||
.. _lilcom: https://github.com/danpovey/lilcom
|
||||
"""
|
||||
|
@ -71,9 +71,12 @@ As the initial step, let's download the pre-trained model.
|
||||
.. code-block:: bash
|
||||
|
||||
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||
$ pushd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
|
||||
$ cd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
|
||||
$ git lfs pull --include "pretrained.pt"
|
||||
$ ln -s pretrained.pt epoch-99.pt # create a symbolic link so that the checkpoint can be loaded
|
||||
$ cd ../data/lang_bpe_500
|
||||
$ git lfs pull --include bpe.model
|
||||
$ cd ../../..
|
||||
|
||||
To test the model, let's have a look at the decoding results **without** using LM. This can be done via the following command:
|
||||
|
||||
|
@ -34,9 +34,12 @@ As the initial step, let's download the pre-trained model.
|
||||
.. code-block:: bash
|
||||
|
||||
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||
$ pushd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
|
||||
$ cd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
|
||||
$ git lfs pull --include "pretrained.pt"
|
||||
$ ln -s pretrained.pt epoch-99.pt # create a symbolic link so that the checkpoint can be loaded
|
||||
$ cd ../data/lang_bpe_500
|
||||
$ git lfs pull --include bpe.model
|
||||
$ cd ../../..
|
||||
|
||||
As usual, we first test the model's performance without external LM. This can be done via the following command:
|
||||
|
||||
|
@ -32,9 +32,12 @@ As the initial step, let's download the pre-trained model.
|
||||
.. code-block:: bash
|
||||
|
||||
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||
$ pushd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
|
||||
$ cd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
|
||||
$ git lfs pull --include "pretrained.pt"
|
||||
$ ln -s pretrained.pt epoch-99.pt # create a symbolic link so that the checkpoint can be loaded
|
||||
$ cd ../data/lang_bpe_500
|
||||
$ git lfs pull --include bpe.model
|
||||
$ cd ../../..
|
||||
|
||||
To test the model, let's have a look at the decoding results without using LM. This can be done via the following command:
|
||||
|
||||
|
180
docs/source/for-dummies/data-preparation.rst
Normal file
180
docs/source/for-dummies/data-preparation.rst
Normal file
@ -0,0 +1,180 @@
|
||||
.. _dummies_tutorial_data_preparation:
|
||||
|
||||
Data Preparation
|
||||
================
|
||||
|
||||
After :ref:`dummies_tutorial_environment_setup`, we can start preparing the
|
||||
data for training and decoding.
|
||||
|
||||
The first step is to prepare the data for training. We have already provided
|
||||
`prepare.sh <https://github.com/k2-fsa/icefall/blob/master/egs/yesno/ASR/prepare.sh>`_
|
||||
that would prepare everything required for training.
|
||||
|
||||
.. code-block::
|
||||
|
||||
cd /tmp/icefall
|
||||
export PYTHONPATH=/tmp/icefall:$PYTHONPATH
|
||||
cd egs/yesno/ASR
|
||||
|
||||
./prepare.sh
|
||||
|
||||
Note that in each recipe from `icefall`_, there exists a file ``prepare.sh``,
|
||||
which you should run before you run anything else.
|
||||
|
||||
That is all you need for data preparation.
|
||||
|
||||
For the more curious
|
||||
--------------------
|
||||
|
||||
If you are wondering how to prepare your own dataset, please refer to the following
|
||||
URLs for more details:
|
||||
|
||||
- `<https://github.com/lhotse-speech/lhotse/tree/master/lhotse/recipes>`_
|
||||
|
||||
It contains recipes for a variety of dataset. If you want to add your own
|
||||
dataset, please read recipes in this folder first.
|
||||
|
||||
- `<https://github.com/lhotse-speech/lhotse/blob/master/lhotse/recipes/yesno.py>`_
|
||||
|
||||
The `yesno`_ recipe in `lhotse`_.
|
||||
|
||||
If you already have a `Kaldi`_ dataset directory, which contains files like
|
||||
``wav.scp``, ``feats.scp``, then you can refer to `<https://lhotse.readthedocs.io/en/latest/kaldi.html#example>`_.
|
||||
|
||||
A quick look to the generated files
|
||||
-----------------------------------
|
||||
|
||||
``./prepare.sh`` puts generated files into two directories:
|
||||
|
||||
- ``download``
|
||||
- ``data``
|
||||
|
||||
download
|
||||
^^^^^^^^
|
||||
|
||||
The ``download`` directory contains downloaded dataset files:
|
||||
|
||||
.. code-block:: bas
|
||||
|
||||
tree -L 1 ./download/
|
||||
|
||||
./download/
|
||||
|-- waves_yesno
|
||||
`-- waves_yesno.tar.gz
|
||||
|
||||
.. hint::
|
||||
|
||||
Please refer to `<https://github.com/lhotse-speech/lhotse/blob/master/lhotse/recipes/yesno.py#L41>`_
|
||||
for how the data is downloaded and extracted.
|
||||
|
||||
data
|
||||
^^^^
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
tree ./data/
|
||||
|
||||
./data/
|
||||
|-- fbank
|
||||
| |-- yesno_cuts_test.jsonl.gz
|
||||
| |-- yesno_cuts_train.jsonl.gz
|
||||
| |-- yesno_feats_test.lca
|
||||
| `-- yesno_feats_train.lca
|
||||
|-- lang_phone
|
||||
| |-- HLG.pt
|
||||
| |-- L.pt
|
||||
| |-- L_disambig.pt
|
||||
| |-- Linv.pt
|
||||
| |-- lexicon.txt
|
||||
| |-- lexicon_disambig.txt
|
||||
| |-- tokens.txt
|
||||
| `-- words.txt
|
||||
|-- lm
|
||||
| |-- G.arpa
|
||||
| `-- G.fst.txt
|
||||
`-- manifests
|
||||
|-- yesno_recordings_test.jsonl.gz
|
||||
|-- yesno_recordings_train.jsonl.gz
|
||||
|-- yesno_supervisions_test.jsonl.gz
|
||||
`-- yesno_supervisions_train.jsonl.gz
|
||||
|
||||
4 directories, 18 files
|
||||
|
||||
**data/manifests**:
|
||||
|
||||
This directory contains manifests. They are used to generate files in
|
||||
``data/fbank``.
|
||||
|
||||
To give you an idea of what it contains, we examine the first few lines of
|
||||
the manifests related to the ``train`` dataset.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
cd data/manifests
|
||||
gunzip -c yesno_recordings_train.jsonl.gz | head -n 3
|
||||
|
||||
The output is given below:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
{"id": "0_0_0_0_1_1_1_1", "sources": [{"type": "file", "channels": [0], "source": "/tmp/icefall/egs/yesno/ASR/download/waves_yesno/0_0_0_0_1_1_1_1.wav"}], "sampling_rate": 8000, "num_samples": 50800, "duration": 6.35, "channel_ids": [0]}
|
||||
{"id": "0_0_0_1_0_1_1_0", "sources": [{"type": "file", "channels": [0], "source": "/tmp/icefall/egs/yesno/ASR/download/waves_yesno/0_0_0_1_0_1_1_0.wav"}], "sampling_rate": 8000, "num_samples": 48880, "duration": 6.11, "channel_ids": [0]}
|
||||
{"id": "0_0_1_0_0_1_1_0", "sources": [{"type": "file", "channels": [0], "source": "/tmp/icefall/egs/yesno/ASR/download/waves_yesno/0_0_1_0_0_1_1_0.wav"}], "sampling_rate": 8000, "num_samples": 48160, "duration": 6.02, "channel_ids": [0]}
|
||||
|
||||
Please refer to `<https://github.com/lhotse-speech/lhotse/blob/master/lhotse/audio.py#L300>`_
|
||||
for the meaning of each field per line.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
gunzip -c yesno_supervisions_train.jsonl.gz | head -n 3
|
||||
|
||||
The output is given below:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
{"id": "0_0_0_0_1_1_1_1", "recording_id": "0_0_0_0_1_1_1_1", "start": 0.0, "duration": 6.35, "channel": 0, "text": "NO NO NO NO YES YES YES YES", "language": "Hebrew"}
|
||||
{"id": "0_0_0_1_0_1_1_0", "recording_id": "0_0_0_1_0_1_1_0", "start": 0.0, "duration": 6.11, "channel": 0, "text": "NO NO NO YES NO YES YES NO", "language": "Hebrew"}
|
||||
{"id": "0_0_1_0_0_1_1_0", "recording_id": "0_0_1_0_0_1_1_0", "start": 0.0, "duration": 6.02, "channel": 0, "text": "NO NO YES NO NO YES YES NO", "language": "Hebrew"}
|
||||
|
||||
Please refer to `<https://github.com/lhotse-speech/lhotse/blob/master/lhotse/supervision.py#L510>`_
|
||||
for the meaning of each field per line.
|
||||
|
||||
**data/fbank**:
|
||||
|
||||
This directory contains everything from ``data/manifests``. Furthermore, it also contains features
|
||||
for training.
|
||||
|
||||
``data/fbank/yesno_feats_train.lca`` contains the features for the train dataset.
|
||||
Features are compressed using `lilcom`_.
|
||||
|
||||
``data/fbank/yesno_cuts_train.jsonl.gz`` stores the `CutSet <https://github.com/lhotse-speech/lhotse/blob/master/lhotse/cut/set.py#L72>`_,
|
||||
which stores `RecordingSet <https://github.com/lhotse-speech/lhotse/blob/master/lhotse/audio.py#L928>`_,
|
||||
`SupervisionSet <https://github.com/lhotse-speech/lhotse/blob/master/lhotse/supervision.py#L510>`_,
|
||||
and `FeatureSet <https://github.com/lhotse-speech/lhotse/blob/master/lhotse/features/base.py#L593>`_.
|
||||
|
||||
To give you an idea about what it looks like, we can run the following command:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
cd data/fbank
|
||||
|
||||
gunzip -c yesno_cuts_train.jsonl.gz | head -n 3
|
||||
|
||||
The output is given below:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
{"id": "0_0_0_0_1_1_1_1-0", "start": 0, "duration": 6.35, "channel": 0, "supervisions": [{"id": "0_0_0_0_1_1_1_1", "recording_id": "0_0_0_0_1_1_1_1", "start": 0.0, "duration": 6.35, "channel": 0, "text": "NO NO NO NO YES YES YES YES", "language": "Hebrew"}], "features": {"type": "kaldi-fbank", "num_frames": 635, "num_features": 23, "frame_shift": 0.01, "sampling_rate": 8000, "start": 0, "duration": 6.35, "storage_type": "lilcom_chunky", "storage_path": "data/fbank/yesno_feats_train.lca", "storage_key": "0,13000,3570", "channels": 0}, "recording": {"id": "0_0_0_0_1_1_1_1", "sources": [{"type": "file", "channels": [0], "source": "/tmp/icefall/egs/yesno/ASR/download/waves_yesno/0_0_0_0_1_1_1_1.wav"}], "sampling_rate": 8000, "num_samples": 50800, "duration": 6.35, "channel_ids": [0]}, "type": "MonoCut"}
|
||||
{"id": "0_0_0_1_0_1_1_0-1", "start": 0, "duration": 6.11, "channel": 0, "supervisions": [{"id": "0_0_0_1_0_1_1_0", "recording_id": "0_0_0_1_0_1_1_0", "start": 0.0, "duration": 6.11, "channel": 0, "text": "NO NO NO YES NO YES YES NO", "language": "Hebrew"}], "features": {"type": "kaldi-fbank", "num_frames": 611, "num_features": 23, "frame_shift": 0.01, "sampling_rate": 8000, "start": 0, "duration": 6.11, "storage_type": "lilcom_chunky", "storage_path": "data/fbank/yesno_feats_train.lca", "storage_key": "16570,12964,2929", "channels": 0}, "recording": {"id": "0_0_0_1_0_1_1_0", "sources": [{"type": "file", "channels": [0], "source": "/tmp/icefall/egs/yesno/ASR/download/waves_yesno/0_0_0_1_0_1_1_0.wav"}], "sampling_rate": 8000, "num_samples": 48880, "duration": 6.11, "channel_ids": [0]}, "type": "MonoCut"}
|
||||
{"id": "0_0_1_0_0_1_1_0-2", "start": 0, "duration": 6.02, "channel": 0, "supervisions": [{"id": "0_0_1_0_0_1_1_0", "recording_id": "0_0_1_0_0_1_1_0", "start": 0.0, "duration": 6.02, "channel": 0, "text": "NO NO YES NO NO YES YES NO", "language": "Hebrew"}], "features": {"type": "kaldi-fbank", "num_frames": 602, "num_features": 23, "frame_shift": 0.01, "sampling_rate": 8000, "start": 0, "duration": 6.02, "storage_type": "lilcom_chunky", "storage_path": "data/fbank/yesno_feats_train.lca", "storage_key": "32463,12936,2696", "channels": 0}, "recording": {"id": "0_0_1_0_0_1_1_0", "sources": [{"type": "file", "channels": [0], "source": "/tmp/icefall/egs/yesno/ASR/download/waves_yesno/0_0_1_0_0_1_1_0.wav"}], "sampling_rate": 8000, "num_samples": 48160, "duration": 6.02, "channel_ids": [0]}, "type": "MonoCut"}
|
||||
|
||||
Note that ``yesno_cuts_train.jsonl.gz`` only stores the information about how to read the features.
|
||||
The actual features are stored separately in ``data/fbank/yesno_feats_train.lca``.
|
||||
|
||||
**data/lang**:
|
||||
|
||||
This directory contains the lexicon.
|
||||
|
||||
**data/lm**:
|
||||
|
||||
This directory contains language models.
|
39
docs/source/for-dummies/decoding.rst
Normal file
39
docs/source/for-dummies/decoding.rst
Normal file
@ -0,0 +1,39 @@
|
||||
.. _dummies_tutorial_decoding:
|
||||
|
||||
Decoding
|
||||
========
|
||||
|
||||
After :ref:`dummies_tutorial_training`, we can start decoding.
|
||||
|
||||
The command to start the decoding is quite simple:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
cd /tmp/icefall
|
||||
export PYTHONPATH=/tmp/icefall:$PYTHONPATH
|
||||
cd egs/yesno/ASR
|
||||
|
||||
# We use CPU for decoding by setting the following environment variable
|
||||
export CUDA_VISIBLE_DEVICES=""
|
||||
|
||||
./tdnn/decode.py
|
||||
|
||||
The output logs are given below:
|
||||
|
||||
.. literalinclude:: ./code/decoding-yesno.txt
|
||||
|
||||
For the more curious
|
||||
--------------------
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
./tdnn/decode.py --help
|
||||
|
||||
will print the usage information about ``./tdnn/decode.py``. For instance, you
|
||||
can specify:
|
||||
|
||||
- ``--epoch`` to use which checkpoint for decoding
|
||||
- ``--avg`` to select how many checkpoints to use for model averaging
|
||||
|
||||
You usually try different combinations of ``--epoch`` and ``--avg`` and select
|
||||
one that leads to the lowest WER (`Word Error Rate <https://en.wikipedia.org/wiki/Word_error_rate>`_).
|
121
docs/source/for-dummies/environment-setup.rst
Normal file
121
docs/source/for-dummies/environment-setup.rst
Normal file
@ -0,0 +1,121 @@
|
||||
.. _dummies_tutorial_environment_setup:
|
||||
|
||||
Environment setup
|
||||
=================
|
||||
|
||||
We will create an environment for `Next-gen Kaldi`_ that runs on ``CPU``
|
||||
in this tutorial.
|
||||
|
||||
.. note::
|
||||
|
||||
Since the `yesno`_ dataset used in this tutorial is very tiny, training on
|
||||
``CPU`` works very well for it.
|
||||
|
||||
If your dataset is very large, e.g., hundreds or thousands of hours of
|
||||
training data, please follow :ref:`install icefall` to install `icefall`_
|
||||
that works with ``GPU``.
|
||||
|
||||
|
||||
Create a virtual environment
|
||||
----------------------------
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
virtualenv -p python3 /tmp/icefall_env
|
||||
|
||||
The above command creates a virtual environment in the directory ``/tmp/icefall_env``.
|
||||
You can select any directory you want.
|
||||
|
||||
The output of the above command is given below:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
Already using interpreter /usr/bin/python3
|
||||
Using base prefix '/usr'
|
||||
New python executable in /tmp/icefall_env/bin/python3
|
||||
Also creating executable in /tmp/icefall_env/bin/python
|
||||
Installing setuptools, pkg_resources, pip, wheel...done.
|
||||
|
||||
Now we can activate the environment using:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
source /tmp/icefall_env/bin/activate
|
||||
|
||||
Install dependencies
|
||||
--------------------
|
||||
|
||||
.. warning::
|
||||
|
||||
Remeber to activate your virtual environment before you continue!
|
||||
|
||||
After activating the virtual environment, we can use the following command
|
||||
to install dependencies of `icefall`_:
|
||||
|
||||
.. hint::
|
||||
|
||||
Remeber that we will run this tutorial on ``CPU``, so we install
|
||||
dependencies required only by running on ``CPU``.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
# Caution: Installation order matters!
|
||||
|
||||
# We use torch 2.0.0 and torchaduio 2.0.0 in this tutorial.
|
||||
# Other versions should also work.
|
||||
|
||||
pip install torch==2.0.0+cpu torchaudio==2.0.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||
|
||||
# If you are using macOS or Windows, please use the following command to install torch and torchaudio
|
||||
# pip install torch==2.0.0 torchaudio==2.0.0 -f https://download.pytorch.org/whl/torch_stable.html
|
||||
|
||||
# Now install k2
|
||||
# Please refer to https://k2-fsa.github.io/k2/installation/from_wheels.html#linux-cpu-example
|
||||
|
||||
pip install k2==1.24.3.dev20230726+cpu.torch2.0.0 -f https://k2-fsa.github.io/k2/cpu.html
|
||||
|
||||
# Install the latest version of lhotse
|
||||
|
||||
pip install git+https://github.com/lhotse-speech/lhotse
|
||||
|
||||
|
||||
Install icefall
|
||||
---------------
|
||||
|
||||
We will put the source code of `icefall`_ into the directory ``/tmp``
|
||||
You can select any directory you want.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
cd /tmp
|
||||
git clone https://github.com/k2-fsa/icefall
|
||||
cd icefall
|
||||
pip install -r ./requirements.txt
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
# Anytime we want to use icefall, we have to set the following
|
||||
# environment variable
|
||||
|
||||
export PYTHONPATH=/tmp/icefall:$PYTHONPATH
|
||||
|
||||
.. hint::
|
||||
|
||||
If you get the following error during this tutorial:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
ModuleNotFoundError: No module named 'icefall'
|
||||
|
||||
please set the above environment variable to fix it.
|
||||
|
||||
|
||||
Congratulations! You have installed `icefall`_ successfully.
|
||||
|
||||
For the more curious
|
||||
--------------------
|
||||
|
||||
`icefall`_ contains a collection of Python scripts and you don't need to
|
||||
use ``python3 setup.py install`` or ``pip install icefall`` to install it.
|
||||
All you need to do is to download the code and set the environment variable
|
||||
``PYTHONPATH``.
|
34
docs/source/for-dummies/index.rst
Normal file
34
docs/source/for-dummies/index.rst
Normal file
@ -0,0 +1,34 @@
|
||||
Icefall for dummies tutorial
|
||||
============================
|
||||
|
||||
This tutorial walks you step by step about how to create a simple
|
||||
ASR (`Automatic Speech Recognition <https://en.wikipedia.org/wiki/Speech_recognition>`_)
|
||||
system with `Next-gen Kaldi`_.
|
||||
|
||||
We use the `yesno`_ dataset for demonstration. We select it out of two reasons:
|
||||
|
||||
- It is quite tiny, containing only about 12 minutes of data
|
||||
- The training can be finished within 20 seconds on ``CPU``.
|
||||
|
||||
That also means you don't need a ``GPU`` to run this tutorial.
|
||||
|
||||
Let's get started!
|
||||
|
||||
Please follow items below **sequentially**.
|
||||
|
||||
.. note::
|
||||
|
||||
The :ref:`dummies_tutorial_data_preparation` runs only on Linux and on macOS.
|
||||
All other parts run on Linux, macOS, and Windows.
|
||||
|
||||
Help from the community is appreciated to port the :ref:`dummies_tutorial_data_preparation`
|
||||
to Windows.
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 2
|
||||
|
||||
./environment-setup.rst
|
||||
./data-preparation.rst
|
||||
./training.rst
|
||||
./decoding.rst
|
||||
./model-export.rst
|
310
docs/source/for-dummies/model-export.rst
Normal file
310
docs/source/for-dummies/model-export.rst
Normal file
@ -0,0 +1,310 @@
|
||||
Model Export
|
||||
============
|
||||
|
||||
There are three ways to export a pre-trained model.
|
||||
|
||||
- Export the model parameters via `model.state_dict() <https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.state_dict>`_
|
||||
- Export via `torchscript <https://pytorch.org/docs/stable/jit.html>`_: either `torch.jit.script() <https://pytorch.org/docs/stable/generated/torch.jit.script.html#torch.jit.script>`_ or `torch.jit.trace() <https://pytorch.org/docs/stable/generated/torch.jit.trace.html>`_
|
||||
- Export to `ONNX`_ via `torch.onnx.export() <https://pytorch.org/docs/stable/onnx.html>`_
|
||||
|
||||
Each method is explained below in detail.
|
||||
|
||||
Export the model parameters via model.state_dict()
|
||||
---------------------------------------------------
|
||||
|
||||
The command for this kind of export is
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
cd /tmp/icefall
|
||||
export PYTHONPATH=/tmp/icefall:$PYTHONPATH
|
||||
cd egs/yesno/ASR
|
||||
|
||||
# assume that "--epoch 14 --avg 2" produces the lowest WER.
|
||||
|
||||
./tdnn/export.py --epoch 14 --avg 2
|
||||
|
||||
The output logs are given below:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
2023-08-16 20:42:03,912 INFO [export.py:76] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lr': 0.01, 'feature_dim': 23, 'weight_decay': 1e-06, 'start_epoch': 0, 'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 10, 'reset_interval': 20, 'valid_interval': 10, 'beam_size': 10, 'reduction': 'sum', 'use_double_scores': True, 'epoch': 14, 'avg': 2, 'jit': False}
|
||||
2023-08-16 20:42:03,913 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt
|
||||
2023-08-16 20:42:03,950 INFO [export.py:93] averaging ['tdnn/exp/epoch-13.pt', 'tdnn/exp/epoch-14.pt']
|
||||
2023-08-16 20:42:03,971 INFO [export.py:106] Not using torch.jit.script
|
||||
2023-08-16 20:42:03,974 INFO [export.py:111] Saved to tdnn/exp/pretrained.pt
|
||||
|
||||
We can see from the logs that the exported model is saved to the file ``tdnn/exp/pretrained.pt``.
|
||||
|
||||
To give you an idea of what ``tdnn/exp/pretrained.pt`` contains, we can use the following command:
|
||||
|
||||
.. code-block:: python3
|
||||
|
||||
>>> import torch
|
||||
>>> m = torch.load("tdnn/exp/pretrained.pt")
|
||||
>>> list(m.keys())
|
||||
['model']
|
||||
>>> list(m["model"].keys())
|
||||
['tdnn.0.weight', 'tdnn.0.bias', 'tdnn.2.running_mean', 'tdnn.2.running_var', 'tdnn.2.num_batches_tracked', 'tdnn.3.weight', 'tdnn.3.bias', 'tdnn.5.running_mean', 'tdnn.5.running_var', 'tdnn.5.num_batches_tracked', 'tdnn.6.weight', 'tdnn.6.bias', 'tdnn.8.running_mean', 'tdnn.8.running_var', 'tdnn.8.num_batches_tracked', 'output_linear.weight', 'output_linear.bias']
|
||||
|
||||
We can use ``tdnn/exp/pretrained.pt`` in the following way with ``./tdnn/decode.py``:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
cd tdnn/exp
|
||||
ln -s pretrained.pt epoch-99.pt
|
||||
cd ../..
|
||||
|
||||
./tdnn/decode.py --epoch 99 --avg 1
|
||||
|
||||
The output logs of the above command are given below:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
2023-08-16 20:45:48,089 INFO [decode.py:262] Decoding started
|
||||
2023-08-16 20:45:48,090 INFO [decode.py:263] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'feature_dim': 23, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'epoch': 99, 'avg': 1, 'export': False, 'feature_dir': PosixPath('data/fbank'), 'max_duration': 30.0, 'bucketing_sampler': False, 'num_buckets': 10, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': False, 'return_cuts': True, 'num_workers': 2, 'env_info': {'k2-version': '1.24.3', 'k2-build-type': 'Release', 'k2-with-cuda': False, 'k2-git-sha1': 'ad79f1c699c684de9785ed6ca5edb805a41f78c3', 'k2-git-date': 'Wed Jul 26 09:30:42 2023', 'lhotse-version': '1.16.0.dev+git.aa073f6.clean', 'torch-version': '2.0.0', 'torch-cuda-available': False, 'torch-cuda-version': None, 'python-version': '3.1', 'icefall-git-branch': 'master', 'icefall-git-sha1': '9a47c08-clean', 'icefall-git-date': 'Mon Aug 14 22:10:50 2023', 'icefall-path': '/private/tmp/icefall', 'k2-path': '/private/tmp/icefall_env/lib/python3.11/site-packages/k2/__init__.py', 'lhotse-path': '/private/tmp/icefall_env/lib/python3.11/site-packages/lhotse/__init__.py', 'hostname': 'fangjuns-MacBook-Pro.local', 'IP address': '127.0.0.1'}}
|
||||
2023-08-16 20:45:48,092 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt
|
||||
2023-08-16 20:45:48,103 INFO [decode.py:272] device: cpu
|
||||
2023-08-16 20:45:48,109 INFO [checkpoint.py:112] Loading checkpoint from tdnn/exp/epoch-99.pt
|
||||
2023-08-16 20:45:48,115 INFO [asr_datamodule.py:218] About to get test cuts
|
||||
2023-08-16 20:45:48,115 INFO [asr_datamodule.py:253] About to get test cuts
|
||||
2023-08-16 20:45:50,386 INFO [decode.py:203] batch 0/?, cuts processed until now is 4
|
||||
2023-08-16 20:45:50,556 INFO [decode.py:240] The transcripts are stored in tdnn/exp/recogs-test_set.txt
|
||||
2023-08-16 20:45:50,557 INFO [utils.py:564] [test_set] %WER 0.42% [1 / 240, 0 ins, 1 del, 0 sub ]
|
||||
2023-08-16 20:45:50,558 INFO [decode.py:248] Wrote detailed error stats to tdnn/exp/errs-test_set.txt
|
||||
2023-08-16 20:45:50,559 INFO [decode.py:315] Done!
|
||||
|
||||
We can see that it produces an identical WER as before.
|
||||
|
||||
We can also use it to decode files with the following command:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
# ./tdnn/pretrained.py requires kaldifeat
|
||||
#
|
||||
# Please refer to https://csukuangfj.github.io/kaldifeat/installation/from_wheels.html
|
||||
# for how to install kaldifeat
|
||||
|
||||
pip install kaldifeat==1.25.0.dev20230726+cpu.torch2.0.0 -f https://csukuangfj.github.io/kaldifeat/cpu.html
|
||||
|
||||
./tdnn/pretrained.py \
|
||||
--checkpoint ./tdnn/exp/pretrained.pt \
|
||||
--HLG ./data/lang_phone/HLG.pt \
|
||||
--words-file ./data/lang_phone/words.txt \
|
||||
download/waves_yesno/0_0_0_1_0_0_0_1.wav \
|
||||
download/waves_yesno/0_0_1_0_0_0_1_0.wav
|
||||
|
||||
The output is given below:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
2023-08-16 20:53:19,208 INFO [pretrained.py:136] {'feature_dim': 23, 'num_classes': 4, 'sample_rate': 8000, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'checkpoint': './tdnn/exp/pretrained.pt', 'words_file': './data/lang_phone/words.txt', 'HLG': './data/lang_phone/HLG.pt', 'sound_files': ['download/waves_yesno/0_0_0_1_0_0_0_1.wav', 'download/waves_yesno/0_0_1_0_0_0_1_0.wav']}
|
||||
2023-08-16 20:53:19,208 INFO [pretrained.py:142] device: cpu
|
||||
2023-08-16 20:53:19,208 INFO [pretrained.py:144] Creating model
|
||||
2023-08-16 20:53:19,212 INFO [pretrained.py:156] Loading HLG from ./data/lang_phone/HLG.pt
|
||||
2023-08-16 20:53:19,213 INFO [pretrained.py:160] Constructing Fbank computer
|
||||
2023-08-16 20:53:19,213 INFO [pretrained.py:170] Reading sound files: ['download/waves_yesno/0_0_0_1_0_0_0_1.wav', 'download/waves_yesno/0_0_1_0_0_0_1_0.wav']
|
||||
2023-08-16 20:53:19,224 INFO [pretrained.py:176] Decoding started
|
||||
2023-08-16 20:53:19,304 INFO [pretrained.py:212]
|
||||
download/waves_yesno/0_0_0_1_0_0_0_1.wav:
|
||||
NO NO NO YES NO NO NO YES
|
||||
|
||||
download/waves_yesno/0_0_1_0_0_0_1_0.wav:
|
||||
NO NO YES NO NO NO YES NO
|
||||
|
||||
|
||||
2023-08-16 20:53:19,304 INFO [pretrained.py:214] Decoding Done
|
||||
|
||||
|
||||
Export via torch.jit.script()
|
||||
-----------------------------
|
||||
|
||||
The command for this kind of export is
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
cd /tmp/icefall
|
||||
export PYTHONPATH=/tmp/icefall:$PYTHONPATH
|
||||
cd egs/yesno/ASR
|
||||
|
||||
# assume that "--epoch 14 --avg 2" produces the lowest WER.
|
||||
|
||||
./tdnn/export.py --epoch 14 --avg 2 --jit true
|
||||
|
||||
The output logs are given below:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
2023-08-16 20:47:44,666 INFO [export.py:76] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lr': 0.01, 'feature_dim': 23, 'weight_decay': 1e-06, 'start_epoch': 0, 'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 10, 'reset_interval': 20, 'valid_interval': 10, 'beam_size': 10, 'reduction': 'sum', 'use_double_scores': True, 'epoch': 14, 'avg': 2, 'jit': True}
|
||||
2023-08-16 20:47:44,667 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt
|
||||
2023-08-16 20:47:44,670 INFO [export.py:93] averaging ['tdnn/exp/epoch-13.pt', 'tdnn/exp/epoch-14.pt']
|
||||
2023-08-16 20:47:44,677 INFO [export.py:100] Using torch.jit.script
|
||||
2023-08-16 20:47:44,843 INFO [export.py:104] Saved to tdnn/exp/cpu_jit.pt
|
||||
|
||||
From the output logs we can see that the generated file is saved to ``tdnn/exp/cpu_jit.pt``.
|
||||
|
||||
Don't be confused by the name ``cpu_jit.pt``. The ``cpu`` part means the model is moved to
|
||||
CPU before exporting. That means, when you load it with:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
torch.jit.load()
|
||||
|
||||
you don't need to specify the argument `map_location <https://pytorch.org/docs/stable/generated/torch.jit.load.html#torch.jit.load>`_
|
||||
and it resides on CPU by default.
|
||||
|
||||
To use ``tdnn/exp/cpu_jit.pt`` with `icefall`_ to decode files, we can use:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
# ./tdnn/jit_pretrained.py requires kaldifeat
|
||||
#
|
||||
# Please refer to https://csukuangfj.github.io/kaldifeat/installation/from_wheels.html
|
||||
# for how to install kaldifeat
|
||||
|
||||
pip install kaldifeat==1.25.0.dev20230726+cpu.torch2.0.0 -f https://csukuangfj.github.io/kaldifeat/cpu.html
|
||||
|
||||
|
||||
./tdnn/jit_pretrained.py \
|
||||
--nn-model ./tdnn/exp/cpu_jit.pt \
|
||||
--HLG ./data/lang_phone/HLG.pt \
|
||||
--words-file ./data/lang_phone/words.txt \
|
||||
download/waves_yesno/0_0_0_1_0_0_0_1.wav \
|
||||
download/waves_yesno/0_0_1_0_0_0_1_0.wav
|
||||
|
||||
The output is given below:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
2023-08-16 20:56:00,603 INFO [jit_pretrained.py:121] {'feature_dim': 23, 'num_classes': 4, 'sample_rate': 8000, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'nn_model': './tdnn/exp/cpu_jit.pt', 'words_file': './data/lang_phone/words.txt', 'HLG': './data/lang_phone/HLG.pt', 'sound_files': ['download/waves_yesno/0_0_0_1_0_0_0_1.wav', 'download/waves_yesno/0_0_1_0_0_0_1_0.wav']}
|
||||
2023-08-16 20:56:00,603 INFO [jit_pretrained.py:127] device: cpu
|
||||
2023-08-16 20:56:00,603 INFO [jit_pretrained.py:129] Loading torchscript model
|
||||
2023-08-16 20:56:00,640 INFO [jit_pretrained.py:134] Loading HLG from ./data/lang_phone/HLG.pt
|
||||
2023-08-16 20:56:00,641 INFO [jit_pretrained.py:138] Constructing Fbank computer
|
||||
2023-08-16 20:56:00,641 INFO [jit_pretrained.py:148] Reading sound files: ['download/waves_yesno/0_0_0_1_0_0_0_1.wav', 'download/waves_yesno/0_0_1_0_0_0_1_0.wav']
|
||||
2023-08-16 20:56:00,642 INFO [jit_pretrained.py:154] Decoding started
|
||||
2023-08-16 20:56:00,727 INFO [jit_pretrained.py:190]
|
||||
download/waves_yesno/0_0_0_1_0_0_0_1.wav:
|
||||
NO NO NO YES NO NO NO YES
|
||||
|
||||
download/waves_yesno/0_0_1_0_0_0_1_0.wav:
|
||||
NO NO YES NO NO NO YES NO
|
||||
|
||||
|
||||
2023-08-16 20:56:00,727 INFO [jit_pretrained.py:192] Decoding Done
|
||||
|
||||
.. hint::
|
||||
|
||||
We provide only code for ``torch.jit.script()``. You can try ``torch.jit.trace()``
|
||||
if you want.
|
||||
|
||||
Export via torch.onnx.export()
|
||||
------------------------------
|
||||
|
||||
The command for this kind of export is
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
cd /tmp/icefall
|
||||
export PYTHONPATH=/tmp/icefall:$PYTHONPATH
|
||||
cd egs/yesno/ASR
|
||||
|
||||
# tdnn/export_onnx.py requires onnx and onnxruntime
|
||||
pip install onnx onnxruntime
|
||||
|
||||
# assume that "--epoch 14 --avg 2" produces the lowest WER.
|
||||
|
||||
./tdnn/export_onnx.py \
|
||||
--epoch 14 \
|
||||
--avg 2
|
||||
|
||||
The output logs are given below:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
2023-08-16 20:59:20,888 INFO [export_onnx.py:83] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lr': 0.01, 'feature_dim': 23, 'weight_decay': 1e-06, 'start_epoch': 0, 'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 10, 'reset_interval': 20, 'valid_interval': 10, 'beam_size': 10, 'reduction': 'sum', 'use_double_scores': True, 'epoch': 14, 'avg': 2}
|
||||
2023-08-16 20:59:20,888 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt
|
||||
2023-08-16 20:59:20,892 INFO [export_onnx.py:100] averaging ['tdnn/exp/epoch-13.pt', 'tdnn/exp/epoch-14.pt']
|
||||
================ Diagnostic Run torch.onnx.export version 2.0.0 ================
|
||||
verbose: False, log level: Level.ERROR
|
||||
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================
|
||||
|
||||
2023-08-16 20:59:21,047 INFO [export_onnx.py:127] Saved to tdnn/exp/model-epoch-14-avg-2.onnx
|
||||
2023-08-16 20:59:21,047 INFO [export_onnx.py:136] meta_data: {'model_type': 'tdnn', 'version': '1', 'model_author': 'k2-fsa', 'comment': 'non-streaming tdnn for the yesno recipe', 'vocab_size': 4}
|
||||
2023-08-16 20:59:21,049 INFO [export_onnx.py:140] Generate int8 quantization models
|
||||
2023-08-16 20:59:21,075 INFO [onnx_quantizer.py:538] Quantization parameters for tensor:"/Transpose_1_output_0" not specified
|
||||
2023-08-16 20:59:21,081 INFO [export_onnx.py:151] Saved to tdnn/exp/model-epoch-14-avg-2.int8.onnx
|
||||
|
||||
We can see from the logs that it generates two files:
|
||||
|
||||
- ``tdnn/exp/model-epoch-14-avg-2.onnx`` (ONNX model with ``float32`` weights)
|
||||
- ``tdnn/exp/model-epoch-14-avg-2.int8.onnx`` (ONNX model with ``int8`` weights)
|
||||
|
||||
To use the generated ONNX model files for decoding with `onnxruntime`_, we can use
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
# ./tdnn/onnx_pretrained.py requires kaldifeat
|
||||
#
|
||||
# Please refer to https://csukuangfj.github.io/kaldifeat/installation/from_wheels.html
|
||||
# for how to install kaldifeat
|
||||
|
||||
pip install kaldifeat==1.25.0.dev20230726+cpu.torch2.0.0 -f https://csukuangfj.github.io/kaldifeat/cpu.html
|
||||
|
||||
./tdnn/onnx_pretrained.py \
|
||||
--nn-model ./tdnn/exp/model-epoch-14-avg-2.onnx \
|
||||
--HLG ./data/lang_phone/HLG.pt \
|
||||
--words-file ./data/lang_phone/words.txt \
|
||||
download/waves_yesno/0_0_0_1_0_0_0_1.wav \
|
||||
download/waves_yesno/0_0_1_0_0_0_1_0.wav
|
||||
|
||||
The output is given below:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
2023-08-16 21:03:24,260 INFO [onnx_pretrained.py:166] {'feature_dim': 23, 'sample_rate': 8000, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'nn_model': './tdnn/exp/model-epoch-14-avg-2.onnx', 'words_file': './data/lang_phone/words.txt', 'HLG': './data/lang_phone/HLG.pt', 'sound_files': ['download/waves_yesno/0_0_0_1_0_0_0_1.wav', 'download/waves_yesno/0_0_1_0_0_0_1_0.wav']}
|
||||
2023-08-16 21:03:24,260 INFO [onnx_pretrained.py:171] device: cpu
|
||||
2023-08-16 21:03:24,260 INFO [onnx_pretrained.py:173] Loading onnx model ./tdnn/exp/model-epoch-14-avg-2.onnx
|
||||
2023-08-16 21:03:24,267 INFO [onnx_pretrained.py:176] Loading HLG from ./data/lang_phone/HLG.pt
|
||||
2023-08-16 21:03:24,270 INFO [onnx_pretrained.py:180] Constructing Fbank computer
|
||||
2023-08-16 21:03:24,273 INFO [onnx_pretrained.py:190] Reading sound files: ['download/waves_yesno/0_0_0_1_0_0_0_1.wav', 'download/waves_yesno/0_0_1_0_0_0_1_0.wav']
|
||||
2023-08-16 21:03:24,279 INFO [onnx_pretrained.py:196] Decoding started
|
||||
2023-08-16 21:03:24,318 INFO [onnx_pretrained.py:232]
|
||||
download/waves_yesno/0_0_0_1_0_0_0_1.wav:
|
||||
NO NO NO YES NO NO NO YES
|
||||
|
||||
download/waves_yesno/0_0_1_0_0_0_1_0.wav:
|
||||
NO NO YES NO NO NO YES NO
|
||||
|
||||
|
||||
2023-08-16 21:03:24,318 INFO [onnx_pretrained.py:234] Decoding Done
|
||||
|
||||
.. note::
|
||||
|
||||
To use the ``int8`` ONNX model for decoding, please use:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
./tdnn/onnx_pretrained.py \
|
||||
--nn-model ./tdnn/exp/model-epoch-14-avg-2.onnx \
|
||||
--HLG ./data/lang_phone/HLG.pt \
|
||||
--words-file ./data/lang_phone/words.txt \
|
||||
download/waves_yesno/0_0_0_1_0_0_0_1.wav \
|
||||
download/waves_yesno/0_0_1_0_0_0_1_0.wav
|
||||
|
||||
For the more curious
|
||||
--------------------
|
||||
|
||||
If you are wondering how to deploy the model without ``torch``, please
|
||||
continue reading. We will show how to use `sherpa-onnx`_ to run the
|
||||
exported ONNX models, which depends only on `onnxruntime`_ and does not
|
||||
depend on ``torch``.
|
||||
|
||||
In this tutorial, we will only demonstrate the usage of `sherpa-onnx`_ with the
|
||||
pre-trained model of the `yesno`_ recipe. There are also other two frameworks
|
||||
available:
|
||||
|
||||
- `sherpa`_. It works with torchscript models.
|
||||
- `sherpa-ncnn`_. It works with models exported using :ref:`icefall_export_to_ncnn` with `ncnn`_
|
||||
|
||||
Please see `<https://k2-fsa.github.io/sherpa/>`_ for further details.
|
39
docs/source/for-dummies/training.rst
Normal file
39
docs/source/for-dummies/training.rst
Normal file
@ -0,0 +1,39 @@
|
||||
.. _dummies_tutorial_training:
|
||||
|
||||
Training
|
||||
========
|
||||
|
||||
After :ref:`dummies_tutorial_data_preparation`, we can start training.
|
||||
|
||||
The command to start the training is quite simple:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
cd /tmp/icefall
|
||||
export PYTHONPATH=/tmp/icefall:$PYTHONPATH
|
||||
cd egs/yesno/ASR
|
||||
|
||||
# We use CPU for training by setting the following environment variable
|
||||
export CUDA_VISIBLE_DEVICES=""
|
||||
|
||||
./tdnn/train.py
|
||||
|
||||
That's it!
|
||||
|
||||
You can find the training logs below:
|
||||
|
||||
.. literalinclude:: ./code/train-yesno.txt
|
||||
|
||||
For the more curious
|
||||
--------------------
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
./tdnn/train.py --help
|
||||
|
||||
will print the usage information about ``./tdnn/train.py``. For instance, you
|
||||
can specify the number of epochs to train and the location to save the training
|
||||
results.
|
||||
|
||||
The training text logs are saved in ``tdnn/exp/log`` while the tensorboard
|
||||
logs are in ``tdnn/exp/tensorboard``.
|
@ -20,6 +20,7 @@ speech recognition recipes using `k2 <https://github.com/k2-fsa/k2>`_.
|
||||
:maxdepth: 2
|
||||
:caption: Contents:
|
||||
|
||||
for-dummies/index.rst
|
||||
installation/index
|
||||
docker/index
|
||||
faqs
|
||||
|
@ -41,7 +41,7 @@ as an example.
|
||||
|
||||
./pruned_transducer_stateless3/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--tokens data/lang_bpe_500/tokens.txt \
|
||||
--epoch 20 \
|
||||
--avg 10
|
||||
|
||||
@ -78,7 +78,7 @@ In each recipe, there is also a file ``pretrained.py``, which can use
|
||||
|
||||
./pruned_transducer_stateless3/pretrained.py \
|
||||
--checkpoint ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/pretrained-iter-1224000-avg-14.pt \
|
||||
--bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/bpe.model \
|
||||
--tokens ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
|
||||
--method greedy_search \
|
||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav \
|
||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav \
|
||||
|
@ -153,11 +153,10 @@ Next, we use the following code to export our model:
|
||||
|
||||
./conv_emformer_transducer_stateless2/export-for-ncnn.py \
|
||||
--exp-dir $dir/exp \
|
||||
--bpe-model $dir/data/lang_bpe_500/bpe.model \
|
||||
--tokens $dir/data/lang_bpe_500/tokens.txt \
|
||||
--epoch 30 \
|
||||
--avg 1 \
|
||||
--use-averaged-model 0 \
|
||||
\
|
||||
--num-encoder-layers 12 \
|
||||
--chunk-length 32 \
|
||||
--cnn-module-kernel 31 \
|
||||
|
@ -73,7 +73,7 @@ Next, we use the following code to export our model:
|
||||
|
||||
./lstm_transducer_stateless2/export-for-ncnn.py \
|
||||
--exp-dir $dir/exp \
|
||||
--bpe-model $dir/data/lang_bpe_500/bpe.model \
|
||||
--tokens $dir/data/lang_bpe_500/tokens.txt \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--use-averaged-model 0 \
|
||||
|
@ -72,12 +72,11 @@ Next, we use the following code to export our model:
|
||||
dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||
|
||||
./pruned_transducer_stateless7_streaming/export-for-ncnn.py \
|
||||
--bpe-model $dir/data/lang_bpe_500/bpe.model \
|
||||
--tokens $dir/data/lang_bpe_500/tokens.txt \
|
||||
--exp-dir $dir/exp \
|
||||
--use-averaged-model 0 \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
\
|
||||
--decode-chunk-len 32 \
|
||||
--num-left-chunks 4 \
|
||||
--num-encoder-layers "2,4,3,2,4" \
|
||||
|
@ -71,7 +71,7 @@ Export the model to ONNX
|
||||
.. code-block:: bash
|
||||
|
||||
./pruned_transducer_stateless7_streaming/export-onnx.py \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||
--use-averaged-model 0 \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
|
@ -32,7 +32,7 @@ as an example in the following.
|
||||
|
||||
./pruned_transducer_stateless3/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--tokens data/lang_bpe_500/tokens.txt \
|
||||
--epoch $epoch \
|
||||
--avg $avg \
|
||||
--jit 1
|
||||
|
@ -33,7 +33,7 @@ as an example in the following.
|
||||
|
||||
./lstm_transducer_stateless2/export.py \
|
||||
--exp-dir ./lstm_transducer_stateless2/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--tokens data/lang_bpe_500/tokens.txt \
|
||||
--iter $iter \
|
||||
--avg $avg \
|
||||
--jit-trace 1
|
||||
|
@ -37,7 +37,7 @@ from lhotse.dataset import (
|
||||
DynamicBucketingSampler,
|
||||
K2SpeechRecognitionDataset,
|
||||
PrecomputedFeatures,
|
||||
SingleCutSampler,
|
||||
SimpleCutSampler,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
||||
@ -292,8 +292,8 @@ class Aidatatang_200zhAsrDataModule:
|
||||
buffer_size=50000,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SingleCutSampler.")
|
||||
train_sampler = SingleCutSampler(
|
||||
logging.info("Using SimpleCutSampler.")
|
||||
train_sampler = SimpleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
|
@ -322,6 +322,7 @@ def export_decoder_model_onnx(
|
||||
vocab_size = decoder_model.decoder.vocab_size
|
||||
|
||||
y = torch.zeros(10, context_size, dtype=torch.int64)
|
||||
decoder_model = torch.jit.script(decoder_model)
|
||||
torch.onnx.export(
|
||||
decoder_model,
|
||||
y,
|
||||
|
@ -151,12 +151,14 @@ class OnnxModel:
|
||||
self.encoder = ort.InferenceSession(
|
||||
encoder_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
def init_decoder(self, decoder_model_filename: str):
|
||||
self.decoder = ort.InferenceSession(
|
||||
decoder_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
|
||||
@ -170,6 +172,7 @@ class OnnxModel:
|
||||
self.joiner = ort.InferenceSession(
|
||||
joiner_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
|
||||
|
@ -30,7 +30,7 @@ from lhotse.dataset import (
|
||||
DynamicBucketingSampler,
|
||||
K2SpeechRecognitionDataset,
|
||||
PrecomputedFeatures,
|
||||
SingleCutSampler,
|
||||
SimpleCutSampler,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
||||
@ -278,8 +278,8 @@ class AishellAsrDataModule:
|
||||
drop_last=self.args.drop_last,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SingleCutSampler.")
|
||||
train_sampler = SingleCutSampler(
|
||||
logging.info("Using SimpleCutSampler.")
|
||||
train_sampler = SimpleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
|
@ -31,7 +31,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||
DynamicBucketingSampler,
|
||||
K2SpeechRecognitionDataset,
|
||||
PrecomputedFeatures,
|
||||
SingleCutSampler,
|
||||
SimpleCutSampler,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
||||
@ -299,8 +299,8 @@ class AiShell2AsrDataModule:
|
||||
drop_last=self.args.drop_last,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SingleCutSampler.")
|
||||
train_sampler = SingleCutSampler(
|
||||
logging.info("Using SimpleCutSampler.")
|
||||
train_sampler = SimpleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
|
@ -30,7 +30,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||
DynamicBucketingSampler,
|
||||
K2SpeechRecognitionDataset,
|
||||
PrecomputedFeatures,
|
||||
SingleCutSampler,
|
||||
SimpleCutSampler,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import ( # noqa F401 for AudioSamples
|
||||
@ -310,8 +310,8 @@ class Aishell4AsrDataModule:
|
||||
drop_last=self.args.drop_last,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SingleCutSampler.")
|
||||
train_sampler = SingleCutSampler(
|
||||
logging.info("Using SimpleCutSampler.")
|
||||
train_sampler = SimpleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
|
@ -37,7 +37,7 @@ from lhotse.dataset import (
|
||||
DynamicBucketingSampler,
|
||||
K2SpeechRecognitionDataset,
|
||||
PrecomputedFeatures,
|
||||
SingleCutSampler,
|
||||
SimpleCutSampler,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
||||
@ -292,8 +292,8 @@ class AlimeetingAsrDataModule:
|
||||
drop_last=True,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SingleCutSampler.")
|
||||
train_sampler = SingleCutSampler(
|
||||
logging.info("Using SimpleCutSampler.")
|
||||
train_sampler = SimpleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
|
@ -257,7 +257,7 @@ class AmiAsrDataModule:
|
||||
drop_last=self.args.drop_last,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SingleCutSampler.")
|
||||
logging.info("Using SimpleCutSampler.")
|
||||
train_sampler = SimpleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
|
@ -30,7 +30,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||
DynamicBucketingSampler,
|
||||
K2SpeechRecognitionDataset,
|
||||
PrecomputedFeatures,
|
||||
SingleCutSampler,
|
||||
SimpleCutSampler,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
||||
@ -311,8 +311,8 @@ class CommonVoiceAsrDataModule:
|
||||
drop_last=self.args.drop_last,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SingleCutSampler.")
|
||||
train_sampler = SingleCutSampler(
|
||||
logging.info("Using SimpleCutSampler.")
|
||||
train_sampler = SimpleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
|
@ -330,6 +330,7 @@ def export_decoder_model_onnx(
|
||||
vocab_size = decoder_model.decoder.vocab_size
|
||||
|
||||
y = torch.zeros(10, context_size, dtype=torch.int64)
|
||||
decoder_model = torch.jit.script(decoder_model)
|
||||
torch.onnx.export(
|
||||
decoder_model,
|
||||
y,
|
||||
|
@ -152,12 +152,14 @@ class OnnxModel:
|
||||
self.encoder = ort.InferenceSession(
|
||||
encoder_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
def init_decoder(self, decoder_model_filename: str):
|
||||
self.decoder = ort.InferenceSession(
|
||||
decoder_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
|
||||
@ -171,6 +173,7 @@ class OnnxModel:
|
||||
self.joiner = ort.InferenceSession(
|
||||
joiner_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
|
||||
|
@ -31,7 +31,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||
DynamicBucketingSampler,
|
||||
K2SpeechRecognitionDataset,
|
||||
PrecomputedFeatures,
|
||||
SingleCutSampler,
|
||||
SimpleCutSampler,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
||||
@ -339,8 +339,8 @@ class CSJAsrDataModule:
|
||||
drop_last=self.args.drop_last,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SingleCutSampler.")
|
||||
train_sampler = SingleCutSampler(
|
||||
logging.info("Using SimpleCutSampler.")
|
||||
train_sampler = SimpleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
|
@ -27,7 +27,7 @@ from lhotse.dataset import (
|
||||
DynamicBucketingSampler,
|
||||
K2SpeechRecognitionDataset,
|
||||
PrecomputedFeatures,
|
||||
SingleCutSampler,
|
||||
SimpleCutSampler,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
||||
@ -264,8 +264,8 @@ class GigaSpeechAsrDataModule:
|
||||
drop_last=True,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SingleCutSampler.")
|
||||
train_sampler = SingleCutSampler(
|
||||
logging.info("Using SimpleCutSampler.")
|
||||
train_sampler = SimpleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
|
@ -30,7 +30,7 @@ from lhotse.dataset import (
|
||||
DynamicBucketingSampler,
|
||||
K2SpeechRecognitionDataset,
|
||||
PrecomputedFeatures,
|
||||
SingleCutSampler,
|
||||
SimpleCutSampler,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
||||
@ -297,8 +297,8 @@ class GigaSpeechAsrDataModule:
|
||||
drop_last=True,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SingleCutSampler.")
|
||||
train_sampler = SingleCutSampler(
|
||||
logging.info("Using SimpleCutSampler.")
|
||||
train_sampler = SimpleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
|
@ -259,7 +259,7 @@ class LibriCssAsrDataModule:
|
||||
drop_last=self.args.drop_last,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SingleCutSampler.")
|
||||
logging.info("Using SimpleCutSampler.")
|
||||
train_sampler = SimpleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
|
@ -1 +0,0 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7/scaling.py
|
1576
egs/libricss/SURT/dprnn_zipformer/scaling.py
Normal file
1576
egs/libricss/SURT/dprnn_zipformer/scaling.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -79,7 +79,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
||||
# ln -sfv /path/to/rirs_noises $dl_dir/
|
||||
#
|
||||
if [ ! -d $dl_dir/rirs_noises ]; then
|
||||
lhotse download rirs_noises $dl_dir
|
||||
lhotse download rir-noise $dl_dir/rirs_noises
|
||||
fi
|
||||
fi
|
||||
|
||||
@ -89,6 +89,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||
# to $dl_dir/librispeech. We perform text normalization for the transcripts.
|
||||
# NOTE: Alignments are required for this recipe.
|
||||
mkdir -p data/manifests
|
||||
|
||||
lhotse prepare librispeech -p train-clean-100 -p train-clean-360 -p train-other-500 -p dev-clean \
|
||||
-j 4 --alignments-dir $dl_dir/libri_alignments/LibriSpeech $dl_dir/librispeech data/manifests/
|
||||
fi
|
||||
@ -112,7 +113,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||
|
||||
# We assume that you have downloaded the RIRS_NOISES corpus
|
||||
# to $dl_dir/rirs_noises
|
||||
lhotse prepare rir-noise -p real_rir -p iso_noise $dl_dir/rirs_noises data/manifests
|
||||
lhotse prepare rir-noise -p real_rir -p iso_noise $dl_dir/rirs_noises/RIRS_NOISES data/manifests
|
||||
fi
|
||||
|
||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
|
@ -401,6 +401,7 @@ def export_decoder_model_onnx(
|
||||
vocab_size = decoder_model.decoder.vocab_size
|
||||
|
||||
y = torch.zeros(10, context_size, dtype=torch.int64)
|
||||
decoder_model = torch.jit.script(decoder_model)
|
||||
torch.onnx.export(
|
||||
decoder_model,
|
||||
y,
|
||||
|
@ -136,6 +136,7 @@ class OnnxModel:
|
||||
self.encoder = ort.InferenceSession(
|
||||
encoder_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
self.init_encoder_states()
|
||||
|
||||
@ -184,6 +185,7 @@ class OnnxModel:
|
||||
self.decoder = ort.InferenceSession(
|
||||
decoder_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
|
||||
@ -197,6 +199,7 @@ class OnnxModel:
|
||||
self.joiner = ort.InferenceSession(
|
||||
joiner_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
|
||||
|
@ -359,6 +359,7 @@ def export_decoder_model_onnx(
|
||||
vocab_size = decoder_model.decoder.vocab_size
|
||||
|
||||
y = torch.zeros(10, context_size, dtype=torch.int64)
|
||||
decoder_model = torch.jit.script(decoder_model)
|
||||
torch.onnx.export(
|
||||
decoder_model,
|
||||
y,
|
||||
|
@ -356,6 +356,7 @@ def export_decoder_model_onnx(
|
||||
vocab_size = decoder_model.decoder.vocab_size
|
||||
|
||||
y = torch.zeros(10, context_size, dtype=torch.int64)
|
||||
decoder_model = torch.jit.script(decoder_model)
|
||||
torch.onnx.export(
|
||||
decoder_model,
|
||||
y,
|
||||
|
@ -129,6 +129,7 @@ class OnnxModel:
|
||||
self.encoder = ort.InferenceSession(
|
||||
encoder_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
self.init_encoder_states()
|
||||
|
||||
@ -166,6 +167,7 @@ class OnnxModel:
|
||||
self.decoder = ort.InferenceSession(
|
||||
decoder_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
|
||||
@ -179,6 +181,7 @@ class OnnxModel:
|
||||
self.joiner = ort.InferenceSession(
|
||||
joiner_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
|
||||
|
@ -172,30 +172,35 @@ class Model:
|
||||
self.encoder = ort.InferenceSession(
|
||||
args.encoder_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
def init_decoder(self, args):
|
||||
self.decoder = ort.InferenceSession(
|
||||
args.decoder_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
def init_joiner(self, args):
|
||||
self.joiner = ort.InferenceSession(
|
||||
args.joiner_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
def init_joiner_encoder_proj(self, args):
|
||||
self.joiner_encoder_proj = ort.InferenceSession(
|
||||
args.joiner_encoder_proj_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
def init_joiner_decoder_proj(self, args):
|
||||
self.joiner_decoder_proj = ort.InferenceSession(
|
||||
args.joiner_decoder_proj_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
def run_encoder(self, x, h0, c0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
|
@ -31,7 +31,7 @@ from lhotse.dataset import (
|
||||
CutMix,
|
||||
K2SpeechRecognitionDataset,
|
||||
PrecomputedFeatures,
|
||||
SingleCutSampler,
|
||||
SimpleCutSampler,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
||||
@ -290,8 +290,8 @@ class LibriSpeechAsrDataModule:
|
||||
drop_last=True,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SingleCutSampler.")
|
||||
train_sampler = SingleCutSampler(
|
||||
logging.info("Using SimpleCutSampler.")
|
||||
train_sampler = SimpleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
|
@ -307,6 +307,7 @@ def export_decoder_model_onnx(
|
||||
vocab_size = decoder_model.decoder.vocab_size
|
||||
|
||||
y = torch.zeros(10, context_size, dtype=torch.int64)
|
||||
decoder_model = torch.jit.script(decoder_model)
|
||||
torch.onnx.export(
|
||||
decoder_model,
|
||||
y,
|
||||
|
@ -1008,7 +1008,7 @@ def modified_beam_search(
|
||||
for i in range(N):
|
||||
B[i].add(
|
||||
Hypothesis(
|
||||
ys=[blank_id] * context_size,
|
||||
ys=[-1] * (context_size - 1) + [blank_id],
|
||||
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||
context_state=None if context_graph is None else context_graph.root,
|
||||
timestamp=[],
|
||||
@ -1217,7 +1217,7 @@ def modified_beam_search_lm_rescore(
|
||||
for i in range(N):
|
||||
B[i].add(
|
||||
Hypothesis(
|
||||
ys=[blank_id] * context_size,
|
||||
ys=[-1] * (context_size - 1) + [blank_id],
|
||||
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||
timestamp=[],
|
||||
)
|
||||
@ -1417,7 +1417,7 @@ def modified_beam_search_lm_rescore_LODR(
|
||||
for i in range(N):
|
||||
B[i].add(
|
||||
Hypothesis(
|
||||
ys=[blank_id] * context_size,
|
||||
ys=[-1] * (context_size - 1) + [blank_id],
|
||||
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||
timestamp=[],
|
||||
)
|
||||
@ -1617,7 +1617,7 @@ def _deprecated_modified_beam_search(
|
||||
B = HypothesisList()
|
||||
B.add(
|
||||
Hypothesis(
|
||||
ys=[blank_id] * context_size,
|
||||
ys=[-1] * (context_size - 1) + [blank_id],
|
||||
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||
timestamp=[],
|
||||
)
|
||||
@ -1753,7 +1753,11 @@ def beam_search(
|
||||
t = 0
|
||||
|
||||
B = HypothesisList()
|
||||
B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0, timestamp=[]))
|
||||
B.add(
|
||||
Hypothesis(
|
||||
ys=[-1] * (context_size - 1) + [blank_id], log_prob=0.0, timestamp=[]
|
||||
)
|
||||
)
|
||||
|
||||
max_sym_per_utt = 20000
|
||||
|
||||
@ -2265,7 +2269,7 @@ def modified_beam_search_ngram_rescoring(
|
||||
for i in range(N):
|
||||
B[i].add(
|
||||
Hypothesis(
|
||||
ys=[blank_id] * context_size,
|
||||
ys=[-1] * (context_size - 1) + [blank_id],
|
||||
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||
state_cost=NgramLmStateCost(ngram_lm),
|
||||
)
|
||||
@ -2385,6 +2389,7 @@ def modified_beam_search_LODR(
|
||||
LODR_lm_scale: float,
|
||||
LM: LmScorer,
|
||||
beam: int = 4,
|
||||
context_graph: Optional[ContextGraph] = None,
|
||||
) -> List[List[int]]:
|
||||
"""This function implements LODR (https://arxiv.org/abs/2203.16776) with
|
||||
`modified_beam_search`. It uses a bi-gram language model as the estimate
|
||||
@ -2446,13 +2451,14 @@ def modified_beam_search_LODR(
|
||||
for i in range(N):
|
||||
B[i].add(
|
||||
Hypothesis(
|
||||
ys=[blank_id] * context_size,
|
||||
ys=[-1] * (context_size - 1) + [blank_id],
|
||||
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||
state=init_states, # state of the NN LM
|
||||
lm_score=init_score.reshape(-1),
|
||||
state_cost=NgramLmStateCost(
|
||||
LODR_lm
|
||||
), # state of the source domain ngram
|
||||
context_state=None if context_graph is None else context_graph.root,
|
||||
)
|
||||
)
|
||||
|
||||
@ -2598,8 +2604,17 @@ def modified_beam_search_LODR(
|
||||
|
||||
hyp_log_prob = topk_log_probs[k] # get score of current hyp
|
||||
new_token = topk_token_indexes[k]
|
||||
|
||||
context_score = 0
|
||||
new_context_state = None if context_graph is None else hyp.context_state
|
||||
if new_token not in (blank_id, unk_id):
|
||||
|
||||
if context_graph is not None:
|
||||
(
|
||||
context_score,
|
||||
new_context_state,
|
||||
) = context_graph.forward_one_step(hyp.context_state, new_token)
|
||||
|
||||
ys.append(new_token)
|
||||
state_cost = hyp.state_cost.forward_one_step(new_token)
|
||||
|
||||
@ -2615,6 +2630,7 @@ def modified_beam_search_LODR(
|
||||
hyp_log_prob += (
|
||||
lm_score[new_token] * lm_scale
|
||||
+ LODR_lm_scale * current_ngram_score
|
||||
+ context_score
|
||||
) # add the lm score
|
||||
|
||||
lm_score = scores[count]
|
||||
@ -2633,10 +2649,31 @@ def modified_beam_search_LODR(
|
||||
state=state,
|
||||
lm_score=lm_score,
|
||||
state_cost=state_cost,
|
||||
context_state=new_context_state,
|
||||
)
|
||||
B[i].add(new_hyp)
|
||||
|
||||
B = B + finalized_B
|
||||
|
||||
# finalize context_state, if the matched contexts do not reach final state
|
||||
# we need to add the score on the corresponding backoff arc
|
||||
if context_graph is not None:
|
||||
finalized_B = [HypothesisList() for _ in range(len(B))]
|
||||
for i, hyps in enumerate(B):
|
||||
for hyp in list(hyps):
|
||||
context_score, new_context_state = context_graph.finalize(
|
||||
hyp.context_state
|
||||
)
|
||||
finalized_B[i].add(
|
||||
Hypothesis(
|
||||
ys=hyp.ys,
|
||||
log_prob=hyp.log_prob + context_score,
|
||||
timestamp=hyp.timestamp,
|
||||
context_state=new_context_state,
|
||||
)
|
||||
)
|
||||
B = finalized_B
|
||||
|
||||
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
|
||||
|
||||
sorted_ans = [h.ys[context_size:] for h in best_hyps]
|
||||
@ -2709,7 +2746,7 @@ def modified_beam_search_lm_shallow_fusion(
|
||||
for i in range(N):
|
||||
B[i].add(
|
||||
Hypothesis(
|
||||
ys=[blank_id] * context_size,
|
||||
ys=[-1] * (context_size - 1) + [blank_id],
|
||||
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||
state=init_states,
|
||||
lm_score=init_score.reshape(-1),
|
||||
|
@ -312,6 +312,7 @@ def export_decoder_model_onnx(
|
||||
vocab_size = decoder_model.decoder.vocab_size
|
||||
|
||||
y = torch.zeros(10, context_size, dtype=torch.int64)
|
||||
decoder_model = torch.jit.script(decoder_model)
|
||||
torch.onnx.export(
|
||||
decoder_model,
|
||||
y,
|
||||
|
@ -150,12 +150,14 @@ class OnnxModel:
|
||||
self.encoder = ort.InferenceSession(
|
||||
encoder_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
def init_decoder(self, decoder_model_filename: str):
|
||||
self.decoder = ort.InferenceSession(
|
||||
decoder_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
|
||||
@ -169,6 +171,7 @@ class OnnxModel:
|
||||
self.joiner = ort.InferenceSession(
|
||||
joiner_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
|
||||
|
@ -78,6 +78,7 @@ def test_conv2d_subsampling():
|
||||
session = ort.InferenceSession(
|
||||
filename,
|
||||
sess_options=options,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
input_nodes = session.get_inputs()
|
||||
@ -133,6 +134,7 @@ def test_rel_pos():
|
||||
session = ort.InferenceSession(
|
||||
filename,
|
||||
sess_options=options,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
input_nodes = session.get_inputs()
|
||||
@ -220,6 +222,7 @@ def test_conformer_encoder_layer():
|
||||
session = ort.InferenceSession(
|
||||
filename,
|
||||
sess_options=options,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
input_nodes = session.get_inputs()
|
||||
@ -304,6 +307,7 @@ def test_conformer_encoder():
|
||||
session = ort.InferenceSession(
|
||||
filename,
|
||||
sess_options=options,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
input_nodes = session.get_inputs()
|
||||
@ -359,6 +363,7 @@ def test_conformer():
|
||||
session = ort.InferenceSession(
|
||||
filename,
|
||||
sess_options=options,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
input_nodes = session.get_inputs()
|
||||
|
@ -404,6 +404,7 @@ def export_decoder_model_onnx(
|
||||
vocab_size = decoder_model.decoder.vocab_size
|
||||
|
||||
y = torch.zeros(10, context_size, dtype=torch.int64)
|
||||
decoder_model = torch.jit.script(decoder_model)
|
||||
torch.onnx.export(
|
||||
decoder_model,
|
||||
y,
|
||||
|
@ -335,6 +335,7 @@ def export_decoder_model_onnx(
|
||||
vocab_size = decoder_model.decoder.vocab_size
|
||||
|
||||
y = torch.zeros(10, context_size, dtype=torch.int64)
|
||||
decoder_model = torch.jit.script(decoder_model)
|
||||
torch.onnx.export(
|
||||
decoder_model,
|
||||
y,
|
||||
|
@ -138,6 +138,7 @@ class OnnxModel:
|
||||
self.encoder = ort.InferenceSession(
|
||||
encoder_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
self.init_encoder_states()
|
||||
|
||||
@ -185,6 +186,7 @@ class OnnxModel:
|
||||
self.decoder = ort.InferenceSession(
|
||||
decoder_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
|
||||
@ -198,6 +200,7 @@ class OnnxModel:
|
||||
self.joiner = ort.InferenceSession(
|
||||
joiner_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
|
||||
|
@ -26,7 +26,7 @@ You can generate the checkpoint with the following command:
|
||||
|
||||
./pruned_transducer_stateless7/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--tokens data/lang_bpe_500/tokens.txt \
|
||||
--epoch 30 \
|
||||
--avg 9
|
||||
|
||||
@ -52,12 +52,12 @@ import torch
|
||||
import torch.nn as nn
|
||||
from alignment import batch_force_alignment
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.utils import AttributeDict, convert_timestamp, parse_timestamp
|
||||
from lhotse import CutSet
|
||||
from lhotse.serialization import SequentialJsonlWriter
|
||||
from lhotse.supervision import AlignmentItem
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.utils import AttributeDict, convert_timestamp, parse_timestamp
|
||||
|
||||
|
||||
def get_parser():
|
||||
|
@ -71,6 +71,10 @@ class Decoder(nn.Module):
|
||||
groups=decoder_dim // 4, # group size == 4
|
||||
bias=False,
|
||||
)
|
||||
else:
|
||||
# To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'`
|
||||
# when inference with torch.jit.script and context_size == 1
|
||||
self.conv = nn.Identity()
|
||||
|
||||
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
|
||||
"""
|
||||
|
@ -329,6 +329,7 @@ def export_decoder_model_onnx(
|
||||
vocab_size = decoder_model.decoder.vocab_size
|
||||
|
||||
y = torch.zeros(10, context_size, dtype=torch.int64)
|
||||
decoder_model = torch.jit.script(decoder_model)
|
||||
torch.onnx.export(
|
||||
decoder_model,
|
||||
y,
|
||||
|
@ -30,7 +30,7 @@ from lhotse.dataset import (
|
||||
DynamicBucketingSampler,
|
||||
K2SpeechRecognitionDataset,
|
||||
PrecomputedFeatures,
|
||||
SingleCutSampler,
|
||||
SimpleCutSampler,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
||||
@ -297,8 +297,8 @@ class GigaSpeechAsrDataModule:
|
||||
drop_last=True,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SingleCutSampler.")
|
||||
train_sampler = SingleCutSampler(
|
||||
logging.info("Using SimpleCutSampler.")
|
||||
train_sampler = SimpleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
|
@ -74,6 +74,7 @@ def test_conv2d_subsampling():
|
||||
session = ort.InferenceSession(
|
||||
filename,
|
||||
sess_options=options,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
input_nodes = session.get_inputs()
|
||||
@ -128,6 +129,7 @@ def test_rel_pos():
|
||||
session = ort.InferenceSession(
|
||||
filename,
|
||||
sess_options=options,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
input_nodes = session.get_inputs()
|
||||
@ -204,6 +206,7 @@ def test_zipformer_encoder_layer():
|
||||
session = ort.InferenceSession(
|
||||
filename,
|
||||
sess_options=options,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
input_nodes = session.get_inputs()
|
||||
@ -284,6 +287,7 @@ def test_zipformer_encoder():
|
||||
session = ort.InferenceSession(
|
||||
filename,
|
||||
sess_options=options,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
input_nodes = session.get_inputs()
|
||||
@ -338,6 +342,7 @@ def test_zipformer():
|
||||
session = ort.InferenceSession(
|
||||
filename,
|
||||
sess_options=options,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
input_nodes = session.get_inputs()
|
||||
|
@ -326,41 +326,49 @@ def main():
|
||||
encoder = ort.InferenceSession(
|
||||
args.encoder_model_filename,
|
||||
sess_options=session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
decoder = ort.InferenceSession(
|
||||
args.decoder_model_filename,
|
||||
sess_options=session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
joiner = ort.InferenceSession(
|
||||
args.joiner_model_filename,
|
||||
sess_options=session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
joiner_encoder_proj = ort.InferenceSession(
|
||||
args.joiner_encoder_proj_model_filename,
|
||||
sess_options=session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
joiner_decoder_proj = ort.InferenceSession(
|
||||
args.joiner_decoder_proj_model_filename,
|
||||
sess_options=session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
lconv = ort.InferenceSession(
|
||||
args.lconv_filename,
|
||||
sess_options=session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
frame_reducer = ort.InferenceSession(
|
||||
args.frame_reducer_filename,
|
||||
sess_options=session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
ctc_output = ort.InferenceSession(
|
||||
args.ctc_output_filename,
|
||||
sess_options=session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
|
@ -413,6 +413,7 @@ def export_decoder_model_onnx(
|
||||
context_size = decoder_model.decoder.context_size
|
||||
vocab_size = decoder_model.decoder.vocab_size
|
||||
y = torch.zeros(10, context_size, dtype=torch.int64)
|
||||
decoder_model = torch.jit.script(decoder_model)
|
||||
torch.onnx.export(
|
||||
decoder_model,
|
||||
y,
|
||||
|
@ -401,6 +401,7 @@ def export_decoder_model_onnx(
|
||||
context_size = decoder_model.decoder.context_size
|
||||
vocab_size = decoder_model.decoder.vocab_size
|
||||
y = torch.zeros(10, context_size, dtype=torch.int64)
|
||||
decoder_model = torch.jit.script(decoder_model)
|
||||
torch.onnx.export(
|
||||
decoder_model,
|
||||
y,
|
||||
|
@ -130,6 +130,7 @@ class OnnxModel:
|
||||
self.encoder = ort.InferenceSession(
|
||||
encoder_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
self.init_encoder_states()
|
||||
|
||||
@ -229,6 +230,7 @@ class OnnxModel:
|
||||
self.decoder = ort.InferenceSession(
|
||||
decoder_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
|
||||
@ -242,6 +244,7 @@ class OnnxModel:
|
||||
self.joiner = ort.InferenceSession(
|
||||
joiner_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
|
||||
|
@ -865,7 +865,7 @@ class ZipformerEncoderLayer(nn.Module):
|
||||
return final_dropout_rate
|
||||
else:
|
||||
return initial_dropout_rate - (
|
||||
initial_dropout_rate * final_dropout_rate
|
||||
initial_dropout_rate - final_dropout_rate
|
||||
) * (self.batch_count / warmup_period)
|
||||
|
||||
def forward(
|
||||
|
@ -230,7 +230,7 @@ class Conformer(Transformer):
|
||||
x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask
|
||||
) # (T, B, F)
|
||||
else:
|
||||
x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F)
|
||||
x = self.encoder(x, pos_emb, src_key_padding_mask=src_key_padding_mask) # (T, B, F)
|
||||
|
||||
if self.normalize_before:
|
||||
x = self.after_norm(x)
|
||||
|
@ -31,7 +31,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||
DynamicBucketingSampler,
|
||||
K2SpeechRecognitionDataset,
|
||||
PrecomputedFeatures,
|
||||
SingleCutSampler,
|
||||
SimpleCutSampler,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
||||
@ -314,8 +314,8 @@ class LibriSpeechAsrDataModule:
|
||||
drop_last=self.args.drop_last,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SingleCutSampler.")
|
||||
train_sampler = SingleCutSampler(
|
||||
logging.info("Using SimpleCutSampler.")
|
||||
train_sampler = SimpleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
|
@ -97,6 +97,7 @@ Usage:
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
@ -122,7 +123,7 @@ from beam_search import (
|
||||
)
|
||||
from train import add_model_arguments, get_model, get_params
|
||||
|
||||
from icefall import LmScorer, NgramLm
|
||||
from icefall import ContextGraph, LmScorer, NgramLm
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
@ -215,6 +216,7 @@ def get_parser():
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- modified_beam_search_LODR
|
||||
- fast_beam_search
|
||||
- fast_beam_search_nbest
|
||||
- fast_beam_search_nbest_oracle
|
||||
@ -251,7 +253,7 @@ def get_parser():
|
||||
type=float,
|
||||
default=0.01,
|
||||
help="""
|
||||
Used only when --decoding_method is fast_beam_search_nbest_LG.
|
||||
Used only when --decoding-method is fast_beam_search_nbest_LG.
|
||||
It specifies the scale for n-gram LM scores.
|
||||
""",
|
||||
)
|
||||
@ -285,7 +287,7 @@ def get_parser():
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame.
|
||||
Used only when --decoding_method is greedy_search""",
|
||||
Used only when --decoding-method is greedy_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -347,6 +349,27 @@ def get_parser():
|
||||
help="ID of the backoff symbol in the ngram LM",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-score",
|
||||
type=float,
|
||||
default=2,
|
||||
help="""
|
||||
The bonus score of each token for the context biasing words/phrases.
|
||||
Used only when --decoding-method is modified_beam_search and
|
||||
modified_beam_search_LODR.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-file",
|
||||
type=str,
|
||||
default="",
|
||||
help="""
|
||||
The path of the context biasing lists, one word/phrase each line
|
||||
Used only when --decoding-method is modified_beam_search and
|
||||
modified_beam_search_LODR.
|
||||
""",
|
||||
)
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
@ -359,6 +382,7 @@ def decode_one_batch(
|
||||
batch: dict,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
context_graph: Optional[ContextGraph] = None,
|
||||
LM: Optional[LmScorer] = None,
|
||||
ngram_lm=None,
|
||||
ngram_lm_scale: float = 0.0,
|
||||
@ -388,7 +412,7 @@ def decode_one_batch(
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
LM:
|
||||
A neural network language model.
|
||||
@ -493,6 +517,7 @@ def decode_one_batch(
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
context_graph=context_graph,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
@ -515,6 +540,7 @@ def decode_one_batch(
|
||||
LODR_lm=ngram_lm,
|
||||
LODR_lm_scale=ngram_lm_scale,
|
||||
LM=LM,
|
||||
context_graph=context_graph,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
@ -578,16 +604,22 @@ def decode_one_batch(
|
||||
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
|
||||
|
||||
return {key: hyps}
|
||||
elif params.decoding_method in (
|
||||
"modified_beam_search_lm_rescore",
|
||||
"modified_beam_search_lm_rescore_LODR",
|
||||
):
|
||||
ans = dict()
|
||||
assert ans_dict is not None
|
||||
for key, hyps in ans_dict.items():
|
||||
hyps = [sp.decode(hyp).split() for hyp in hyps]
|
||||
ans[f"beam_size_{params.beam_size}_{key}"] = hyps
|
||||
return ans
|
||||
elif "modified_beam_search" in params.decoding_method:
|
||||
prefix = f"beam_size_{params.beam_size}"
|
||||
if params.decoding_method in (
|
||||
"modified_beam_search_lm_rescore",
|
||||
"modified_beam_search_lm_rescore_LODR",
|
||||
):
|
||||
ans = dict()
|
||||
assert ans_dict is not None
|
||||
for key, hyps in ans_dict.items():
|
||||
hyps = [sp.decode(hyp).split() for hyp in hyps]
|
||||
ans[f"{prefix}_{key}"] = hyps
|
||||
return ans
|
||||
else:
|
||||
if params.has_contexts:
|
||||
prefix += f"-context-score-{params.context_score}"
|
||||
return {prefix: hyps}
|
||||
else:
|
||||
return {f"beam_size_{params.beam_size}": hyps}
|
||||
|
||||
@ -599,6 +631,7 @@ def decode_dataset(
|
||||
sp: spm.SentencePieceProcessor,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
context_graph: Optional[ContextGraph] = None,
|
||||
LM: Optional[LmScorer] = None,
|
||||
ngram_lm=None,
|
||||
ngram_lm_scale: float = 0.0,
|
||||
@ -618,7 +651,7 @@ def decode_dataset(
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
@ -649,6 +682,7 @@ def decode_dataset(
|
||||
model=model,
|
||||
sp=sp,
|
||||
decoding_graph=decoding_graph,
|
||||
context_graph=context_graph,
|
||||
word_table=word_table,
|
||||
batch=batch,
|
||||
LM=LM,
|
||||
@ -744,6 +778,11 @@ def main():
|
||||
)
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
if os.path.exists(params.context_file):
|
||||
params.has_contexts = True
|
||||
else:
|
||||
params.has_contexts = False
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
@ -770,6 +809,12 @@ def main():
|
||||
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
if params.decoding_method in (
|
||||
"modified_beam_search",
|
||||
"modified_beam_search_LODR",
|
||||
):
|
||||
if params.has_contexts:
|
||||
params.suffix += f"-context-score-{params.context_score}"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
@ -952,6 +997,18 @@ def main():
|
||||
decoding_graph = None
|
||||
word_table = None
|
||||
|
||||
if "modified_beam_search" in params.decoding_method:
|
||||
if os.path.exists(params.context_file):
|
||||
contexts = []
|
||||
for line in open(params.context_file).readlines():
|
||||
contexts.append(line.strip())
|
||||
context_graph = ContextGraph(params.context_score)
|
||||
context_graph.build(sp.encode(contexts))
|
||||
else:
|
||||
context_graph = None
|
||||
else:
|
||||
context_graph = None
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
@ -976,6 +1033,7 @@ def main():
|
||||
sp=sp,
|
||||
word_table=word_table,
|
||||
decoding_graph=decoding_graph,
|
||||
context_graph=context_graph,
|
||||
LM=LM,
|
||||
ngram_lm=ngram_lm,
|
||||
ngram_lm_scale=ngram_lm_scale,
|
||||
|
@ -506,6 +506,7 @@ def export_decoder_model_onnx(
|
||||
vocab_size = decoder_model.decoder.vocab_size
|
||||
|
||||
y = torch.zeros(10, context_size, dtype=torch.int64)
|
||||
decoder_model = torch.jit.script(decoder_model)
|
||||
torch.onnx.export(
|
||||
decoder_model,
|
||||
y,
|
||||
|
@ -353,6 +353,7 @@ def export_decoder_model_onnx(
|
||||
vocab_size = decoder_model.decoder.vocab_size
|
||||
|
||||
y = torch.zeros(10, context_size, dtype=torch.int64)
|
||||
decoder_model = torch.jit.script(decoder_model)
|
||||
torch.onnx.export(
|
||||
decoder_model,
|
||||
y,
|
||||
|
@ -146,6 +146,7 @@ class OnnxModel:
|
||||
self.encoder = ort.InferenceSession(
|
||||
encoder_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
self.init_encoder_states()
|
||||
|
||||
@ -236,6 +237,7 @@ class OnnxModel:
|
||||
self.decoder = ort.InferenceSession(
|
||||
decoder_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
|
||||
@ -249,6 +251,7 @@ class OnnxModel:
|
||||
self.joiner = ort.InferenceSession(
|
||||
joiner_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
|
||||
|
@ -151,12 +151,14 @@ class OnnxModel:
|
||||
self.encoder = ort.InferenceSession(
|
||||
encoder_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
def init_decoder(self, decoder_model_filename: str):
|
||||
self.decoder = ort.InferenceSession(
|
||||
decoder_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
|
||||
@ -170,6 +172,7 @@ class OnnxModel:
|
||||
self.joiner = ort.InferenceSession(
|
||||
joiner_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -17,7 +17,7 @@ from lhotse.dataset import (
|
||||
DynamicBucketingSampler,
|
||||
K2SpeechRecognitionDataset,
|
||||
PrecomputedFeatures,
|
||||
SingleCutSampler,
|
||||
SimpleCutSampler,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
||||
@ -270,8 +270,8 @@ class MGB2AsrDataModule:
|
||||
drop_last=self.args.drop_last,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SingleCutSampler.")
|
||||
train_sampler = SingleCutSampler(
|
||||
logging.info("Using SimpleCutSampler.")
|
||||
train_sampler = SimpleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
|
39
egs/multi_zh-hans/ASR/README.md
Normal file
39
egs/multi_zh-hans/ASR/README.md
Normal file
@ -0,0 +1,39 @@
|
||||
|
||||
# Introduction
|
||||
|
||||
This recipe includes scripts for training Zipformer model using multiple Chinese datasets.
|
||||
|
||||
# Included Training Sets
|
||||
1. THCHS-30
|
||||
2. AiShell-{1,2,4}
|
||||
3. ST-CMDS
|
||||
4. Primewords
|
||||
5. MagicData
|
||||
6. Aidatatang_200zh
|
||||
7. AliMeeting
|
||||
8. WeNetSpeech
|
||||
9. KeSpeech-ASR
|
||||
|
||||
|Datset| Number of hours| URL|
|
||||
|---|---:|---|
|
||||
|**TOTAL**|14,106|---|
|
||||
|THCHS-30|35|https://www.openslr.org/18/|
|
||||
|AiShell-1|170|https://www.openslr.org/33/|
|
||||
|AiShell-2|1,000|http://www.aishelltech.com/aishell_2|
|
||||
|AiShell-4|120|https://www.openslr.org/111/|
|
||||
|ST-CMDS|110|https://www.openslr.org/38/|
|
||||
|Primewords|99|https://www.openslr.org/47/|
|
||||
|aidatatang_200zh|200|https://www.openslr.org/62/|
|
||||
|MagicData|755|https://www.openslr.org/68/|
|
||||
|AliMeeting|100|https://openslr.org/119/|
|
||||
|WeNetSpeech|10,000|https://github.com/wenet-e2e/WenetSpeech|
|
||||
|KeSpeech|1,542|https://github.com/KeSpeech/KeSpeech|
|
||||
|
||||
|
||||
# Included Test Sets
|
||||
1. Aishell-{1,2,4}
|
||||
2. Aidatatang_200zh
|
||||
3. AliMeeting
|
||||
4. MagicData
|
||||
5. KeSpeech-ASR
|
||||
6. WeNetSpeech
|
38
egs/multi_zh-hans/ASR/RESULTS.md
Normal file
38
egs/multi_zh-hans/ASR/RESULTS.md
Normal file
@ -0,0 +1,38 @@
|
||||
## Results
|
||||
|
||||
### Multi Chinese datasets char-based training results (Non-streaming) on zipformer model
|
||||
|
||||
This is the [pull request #1238](https://github.com/k2-fsa/icefall/pull/1238) in icefall.
|
||||
|
||||
#### Non-streaming
|
||||
|
||||
Best results (num of params : ~69M):
|
||||
|
||||
The training command:
|
||||
|
||||
```
|
||||
./zipformer/train.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 20 \
|
||||
--use-fp16 1 \
|
||||
--max-duration 600 \
|
||||
--num-workers 8
|
||||
```
|
||||
|
||||
The decoding command:
|
||||
|
||||
```
|
||||
./zipformer/decode.py \
|
||||
--epoch 20 \
|
||||
--avg 1
|
||||
```
|
||||
|
||||
Character Error Rates (CERs) listed below are produced by the checkpoint of the 20th epoch using greedy search and BPE model ( # tokens is 2000, byte fallback enabled).
|
||||
|
||||
| Datasets | aidatatang _200zh | aidatatang _200zh | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech |
|
||||
|--------------------------------|------------------------------|-------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|--------------------|-------------------------|---------------------|
|
||||
| Zipformer CER (%) | dev | test | eval | test | dev | test | dev | test | test | dev | test | dev phase1 | dev phase2 | test | dev | test meeting | test net |
|
||||
| | 3.2 | 3.67 | 23.15 | 24.78 | 2.91 | 3.04 | 3.59 | 4.03 | 15.68 | 3.68 | 3.12 | 6.69 | 3.19 | 8.01 | 9.32 | 7.05 | 8.78 |
|
||||
|
||||
|
||||
The pre-trained model is available here : https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2
|
37
egs/multi_zh-hans/ASR/local/bpe_model_to_tokens.py
Executable file
37
egs/multi_zh-hans/ASR/local/bpe_model_to_tokens.py
Executable file
@ -0,0 +1,37 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""
|
||||
This script takes `bpe.model` as input and generates a file `tokens.txt`
|
||||
from it.
|
||||
|
||||
Usage:
|
||||
./bpe_model_to_tokens.py /path/to/input/bpe.model > tokens.txt
|
||||
"""
|
||||
import argparse
|
||||
|
||||
import sentencepiece as spm
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"bpe_model",
|
||||
type=str,
|
||||
help="Path to the input bpe.model",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(args.bpe_model)
|
||||
|
||||
for i in range(sp.vocab_size()):
|
||||
print(sp.id_to_piece(i), i)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
1
egs/multi_zh-hans/ASR/local/compile_lg.py
Symbolic link
1
egs/multi_zh-hans/ASR/local/compile_lg.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/local/compile_lg.py
|
93
egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_dev_test.py
Executable file
93
egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_dev_test.py
Executable file
@ -0,0 +1,93 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Johns Hopkins University (Piotr Żelasko)
|
||||
# Copyright 2021 Xiaomi Corp. (Fangjun Kuang)
|
||||
# Copyright 2023 Xiaomi Corp. (Zengrui Jin)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slow things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def compute_fbank_kespeech_dev_test():
|
||||
in_out_dir = Path("data/fbank/kespeech")
|
||||
# number of workers in dataloader
|
||||
num_workers = 42
|
||||
|
||||
# number of seconds in a batch
|
||||
batch_duration = 600
|
||||
|
||||
subsets = (
|
||||
"dev_phase1",
|
||||
"dev_phase2",
|
||||
"test",
|
||||
)
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
for partition in subsets:
|
||||
cuts_path = in_out_dir / f"kespeech-asr_cuts_{partition}.jsonl.gz"
|
||||
if cuts_path.is_file():
|
||||
logging.info(f"{cuts_path} exists - skipping")
|
||||
continue
|
||||
|
||||
raw_cuts_path = in_out_dir / f"kespeech-asr_cuts_{partition}_raw.jsonl.gz"
|
||||
|
||||
logging.info(f"Loading {raw_cuts_path}")
|
||||
cut_set = CutSet.from_file(raw_cuts_path)
|
||||
|
||||
logging.info("Splitting cuts into smaller chunks")
|
||||
cut_set = cut_set.trim_to_supervisions(
|
||||
keep_overlapping=False, min_duration=None
|
||||
)
|
||||
|
||||
logging.info("Computing features")
|
||||
cut_set = cut_set.compute_and_store_features_batch(
|
||||
extractor=extractor,
|
||||
storage_path=f"{in_out_dir}/feats_{partition}",
|
||||
num_workers=num_workers,
|
||||
batch_duration=batch_duration,
|
||||
storage_type=LilcomChunkyWriter,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
logging.info(f"Saving to {cuts_path}")
|
||||
cut_set.to_file(cuts_path)
|
||||
|
||||
|
||||
def main():
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
compute_fbank_kespeech_dev_test()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
180
egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_splits.py
Executable file
180
egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_splits.py
Executable file
@ -0,0 +1,180 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Johns Hopkins University (Piotr Żelasko)
|
||||
# Copyright 2021 Xiaomi Corp. (Fangjun Kuang)
|
||||
# Copyright 2023 Xiaomi Corp. (Zengrui Jin)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import (
|
||||
CutSet,
|
||||
KaldifeatFbank,
|
||||
KaldifeatFbankConfig,
|
||||
LilcomChunkyWriter,
|
||||
set_audio_duration_mismatch_tolerance,
|
||||
set_caching_enabled,
|
||||
)
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slow things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--training-subset",
|
||||
type=str,
|
||||
default="train_phase1",
|
||||
choices=["train_phase1", "train_phase2"],
|
||||
help="The training subset for computing fbank feature.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=20,
|
||||
help="Number of dataloading workers used for reading the audio.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--batch-duration",
|
||||
type=float,
|
||||
default=600.0,
|
||||
help="The maximum number of audio seconds in a batch."
|
||||
"Determines batch size dynamically.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-splits",
|
||||
type=int,
|
||||
required=True,
|
||||
help="The number of splits of the given subset",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--start",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Process pieces starting from this number (inclusive).",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--stop",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="Stop processing pieces until this number (exclusive).",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def compute_fbank_kespeech_splits(args):
|
||||
subset = args.training_subset
|
||||
subset = str(subset)
|
||||
num_splits = args.num_splits
|
||||
output_dir = f"data/fbank/kespeech/{subset}_split_{num_splits}"
|
||||
output_dir = Path(output_dir)
|
||||
assert output_dir.exists(), f"{output_dir} does not exist!"
|
||||
|
||||
num_digits = len(str(num_splits))
|
||||
|
||||
start = args.start
|
||||
stop = args.stop
|
||||
if stop < start:
|
||||
stop = num_splits
|
||||
|
||||
stop = min(stop, num_splits)
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
set_audio_duration_mismatch_tolerance(0.01) # 10ms tolerance
|
||||
set_caching_enabled(False)
|
||||
for i in range(start, stop):
|
||||
idx = f"{i + 1}".zfill(num_digits)
|
||||
logging.info(f"Processing {idx}/{num_splits}")
|
||||
|
||||
cuts_path = output_dir / f"kespeech-asr_cuts_{subset}.{idx}.jsonl.gz"
|
||||
if cuts_path.is_file():
|
||||
logging.info(f"{cuts_path} exists - skipping")
|
||||
continue
|
||||
|
||||
raw_cuts_path = output_dir / f"kespeech-asr_cuts_{subset}_raw.{idx}.jsonl.gz"
|
||||
|
||||
logging.info(f"Loading {raw_cuts_path}")
|
||||
cut_set = CutSet.from_file(raw_cuts_path)
|
||||
|
||||
logging.info("Splitting cuts into smaller chunks.")
|
||||
cut_set = cut_set.trim_to_supervisions(
|
||||
keep_overlapping=False, min_duration=None
|
||||
)
|
||||
|
||||
logging.info("Computing features")
|
||||
cut_set = cut_set.compute_and_store_features_batch(
|
||||
extractor=extractor,
|
||||
storage_path=f"{output_dir}/feats_{subset}_{idx}",
|
||||
num_workers=args.num_workers,
|
||||
batch_duration=args.batch_duration,
|
||||
storage_type=LilcomChunkyWriter,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
logging.info(f"Saving to {cuts_path}")
|
||||
cut_set.to_file(cuts_path)
|
||||
|
||||
|
||||
def main():
|
||||
now = datetime.now()
|
||||
date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
|
||||
|
||||
log_filename = "log-compute_fbank_kespeech_splits"
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
log_filename = f"{log_filename}-{date_time}"
|
||||
|
||||
logging.basicConfig(
|
||||
filename=log_filename,
|
||||
format=formatter,
|
||||
level=logging.INFO,
|
||||
filemode="w",
|
||||
)
|
||||
|
||||
console = logging.StreamHandler()
|
||||
console.setLevel(logging.INFO)
|
||||
console.setFormatter(logging.Formatter(formatter))
|
||||
logging.getLogger("").addHandler(console)
|
||||
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
compute_fbank_kespeech_splits(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
122
egs/multi_zh-hans/ASR/local/compute_fbank_magicdata.py
Executable file
122
egs/multi_zh-hans/ASR/local/compute_fbank_magicdata.py
Executable file
@ -0,0 +1,122 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang
|
||||
# Zengrui Jin)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This file computes fbank features of the MagicData dataset.
|
||||
It looks for manifests in the directory data/manifests/magicdata.
|
||||
|
||||
The generated fbank features are saved in data/fbank.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import get_executor
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slow things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def compute_fbank_magicdata(num_mel_bins: int = 80, speed_perturb: bool = False):
|
||||
src_dir = Path("data/manifests/magicdata")
|
||||
output_dir = Path("data/fbank")
|
||||
num_jobs = min(30, os.cpu_count())
|
||||
|
||||
dataset_parts = ("train", "test", "dev")
|
||||
prefix = "magicdata"
|
||||
suffix = "jsonl.gz"
|
||||
manifests = read_manifests_if_cached(
|
||||
dataset_parts=dataset_parts,
|
||||
output_dir=src_dir,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
)
|
||||
assert manifests is not None
|
||||
|
||||
assert len(manifests) == len(dataset_parts), (
|
||||
len(manifests),
|
||||
len(dataset_parts),
|
||||
list(manifests.keys()),
|
||||
dataset_parts,
|
||||
)
|
||||
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||
|
||||
with get_executor() as ex: # Initialize the executor only once.
|
||||
for partition, m in manifests.items():
|
||||
if (output_dir / f"{prefix}_cuts_{partition}.{suffix}").is_file():
|
||||
logging.info(f"{partition} already exists - skipping.")
|
||||
continue
|
||||
logging.info(f"Processing {partition}")
|
||||
cut_set = CutSet.from_manifests(
|
||||
recordings=m["recordings"],
|
||||
supervisions=m["supervisions"],
|
||||
)
|
||||
if "train" in partition and speed_perturb:
|
||||
cut_set = (
|
||||
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||
)
|
||||
cut_set = cut_set.compute_and_store_features(
|
||||
extractor=extractor,
|
||||
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
|
||||
# when an executor is specified, make more partitions
|
||||
num_jobs=num_jobs if ex is None else 80,
|
||||
executor=ex,
|
||||
storage_type=LilcomChunkyWriter,
|
||||
)
|
||||
cut_set.to_file(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--num-mel-bins",
|
||||
type=int,
|
||||
default=80,
|
||||
help="""The number of mel bins for Fbank""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--speed-perturb",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
args = get_args()
|
||||
compute_fbank_magicdata(
|
||||
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb
|
||||
)
|
122
egs/multi_zh-hans/ASR/local/compute_fbank_primewords.py
Executable file
122
egs/multi_zh-hans/ASR/local/compute_fbank_primewords.py
Executable file
@ -0,0 +1,122 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang
|
||||
# Zengrui Jin)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This file computes fbank features of the Primewords dataset.
|
||||
It looks for manifests in the directory data/manifests/primewords.
|
||||
|
||||
The generated fbank features are saved in data/fbank.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import get_executor
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slow things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def compute_fbank_primewords(num_mel_bins: int = 80, speed_perturb: bool = False):
|
||||
src_dir = Path("data/manifests/primewords")
|
||||
output_dir = Path("data/fbank")
|
||||
num_jobs = min(15, os.cpu_count())
|
||||
|
||||
dataset_parts = ("train",)
|
||||
prefix = "primewords"
|
||||
suffix = "jsonl.gz"
|
||||
manifests = read_manifests_if_cached(
|
||||
dataset_parts=dataset_parts,
|
||||
output_dir=src_dir,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
)
|
||||
assert manifests is not None
|
||||
|
||||
assert len(manifests) == len(dataset_parts), (
|
||||
len(manifests),
|
||||
len(dataset_parts),
|
||||
list(manifests.keys()),
|
||||
dataset_parts,
|
||||
)
|
||||
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||
|
||||
with get_executor() as ex: # Initialize the executor only once.
|
||||
for partition, m in manifests.items():
|
||||
if (output_dir / f"{prefix}_cuts_{partition}.{suffix}").is_file():
|
||||
logging.info(f"{partition} already exists - skipping.")
|
||||
continue
|
||||
logging.info(f"Processing {partition}")
|
||||
cut_set = CutSet.from_manifests(
|
||||
recordings=m["recordings"],
|
||||
supervisions=m["supervisions"],
|
||||
)
|
||||
if "train" in partition and speed_perturb:
|
||||
cut_set = (
|
||||
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||
)
|
||||
cut_set = cut_set.compute_and_store_features(
|
||||
extractor=extractor,
|
||||
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
|
||||
# when an executor is specified, make more partitions
|
||||
num_jobs=num_jobs if ex is None else 80,
|
||||
executor=ex,
|
||||
storage_type=LilcomChunkyWriter,
|
||||
)
|
||||
cut_set.to_file(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--num-mel-bins",
|
||||
type=int,
|
||||
default=80,
|
||||
help="""The number of mel bins for Fbank""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--speed-perturb",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
args = get_args()
|
||||
compute_fbank_primewords(
|
||||
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb
|
||||
)
|
121
egs/multi_zh-hans/ASR/local/compute_fbank_stcmds.py
Executable file
121
egs/multi_zh-hans/ASR/local/compute_fbank_stcmds.py
Executable file
@ -0,0 +1,121 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang
|
||||
# Zengrui Jin)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This file computes fbank features of the ST-CMDS dataset.
|
||||
It looks for manifests in the directory data/manifests/stcmds.
|
||||
|
||||
The generated fbank features are saved in data/fbank.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import get_executor
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slow things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def compute_fbank_stcmds(num_mel_bins: int = 80, speed_perturb: bool = False):
|
||||
src_dir = Path("data/manifests/stcmds")
|
||||
output_dir = Path("data/fbank")
|
||||
num_jobs = min(15, os.cpu_count())
|
||||
|
||||
dataset_parts = ("train",)
|
||||
prefix = "stcmds"
|
||||
suffix = "jsonl.gz"
|
||||
manifests = read_manifests_if_cached(
|
||||
dataset_parts=dataset_parts,
|
||||
output_dir=src_dir,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
)
|
||||
assert manifests is not None
|
||||
|
||||
assert len(manifests) == len(dataset_parts), (
|
||||
len(manifests),
|
||||
len(dataset_parts),
|
||||
list(manifests.keys()),
|
||||
dataset_parts,
|
||||
)
|
||||
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||
|
||||
with get_executor() as ex: # Initialize the executor only once.
|
||||
for partition, m in manifests.items():
|
||||
if (output_dir / f"{prefix}_cuts_{partition}.{suffix}").is_file():
|
||||
logging.info(f"{partition} already exists - skipping.")
|
||||
continue
|
||||
logging.info(f"Processing {partition}")
|
||||
cut_set = CutSet.from_manifests(
|
||||
recordings=m["recordings"],
|
||||
supervisions=m["supervisions"],
|
||||
)
|
||||
if "train" in partition and speed_perturb:
|
||||
cut_set = (
|
||||
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||
)
|
||||
cut_set = cut_set.compute_and_store_features(
|
||||
extractor=extractor,
|
||||
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
|
||||
# when an executor is specified, make more partitions
|
||||
num_jobs=num_jobs if ex is None else 80,
|
||||
executor=ex,
|
||||
storage_type=LilcomChunkyWriter,
|
||||
)
|
||||
cut_set.to_file(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--num-mel-bins",
|
||||
type=int,
|
||||
default=80,
|
||||
help="""The number of mel bins for Fbank""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--speed-perturb",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
args = get_args()
|
||||
compute_fbank_stcmds(
|
||||
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb
|
||||
)
|
127
egs/multi_zh-hans/ASR/local/compute_fbank_thchs30.py
Executable file
127
egs/multi_zh-hans/ASR/local/compute_fbank_thchs30.py
Executable file
@ -0,0 +1,127 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang
|
||||
# Zengrui Jin)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This file computes fbank features of the THCHS-30 dataset.
|
||||
It looks for manifests in the directory data/manifests/thchs30.
|
||||
|
||||
The generated fbank features are saved in data/fbank.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import get_executor
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slow things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def compute_fbank_thchs30(num_mel_bins: int = 80, speed_perturb: bool = False):
|
||||
src_dir = Path("data/manifests/thchs30")
|
||||
output_dir = Path("data/fbank")
|
||||
num_jobs = min(15, os.cpu_count())
|
||||
|
||||
dataset_parts = (
|
||||
"train",
|
||||
"dev",
|
||||
"test",
|
||||
)
|
||||
prefix = "thchs_30"
|
||||
suffix = "jsonl.gz"
|
||||
manifests = read_manifests_if_cached(
|
||||
dataset_parts=dataset_parts,
|
||||
output_dir=src_dir,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
)
|
||||
assert manifests is not None
|
||||
|
||||
assert len(manifests) == len(dataset_parts), (
|
||||
len(manifests),
|
||||
len(dataset_parts),
|
||||
list(manifests.keys()),
|
||||
dataset_parts,
|
||||
)
|
||||
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||
|
||||
with get_executor() as ex: # Initialize the executor only once.
|
||||
for partition, m in manifests.items():
|
||||
if (output_dir / f"{prefix}_cuts_{partition}.{suffix}").is_file():
|
||||
logging.info(f"{partition} already exists - skipping.")
|
||||
continue
|
||||
logging.info(f"Processing {partition}")
|
||||
cut_set = CutSet.from_manifests(
|
||||
recordings=m["recordings"],
|
||||
supervisions=m["supervisions"],
|
||||
)
|
||||
if "train" in partition:
|
||||
cut_set = (
|
||||
(cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1))
|
||||
if speed_perturb
|
||||
else cut_set
|
||||
)
|
||||
cut_set = cut_set.compute_and_store_features(
|
||||
extractor=extractor,
|
||||
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
|
||||
# when an executor is specified, make more partitions
|
||||
num_jobs=num_jobs if ex is None else 80,
|
||||
executor=ex,
|
||||
storage_type=LilcomChunkyWriter,
|
||||
)
|
||||
cut_set.to_file(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--num-mel-bins",
|
||||
type=int,
|
||||
default=80,
|
||||
help="""The number of mel bins for Fbank""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--speed-perturb",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
args = get_args()
|
||||
compute_fbank_thchs30(
|
||||
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb
|
||||
)
|
1
egs/multi_zh-hans/ASR/local/prepare_char.py
Symbolic link
1
egs/multi_zh-hans/ASR/local/prepare_char.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../wenetspeech/ASR/local/prepare_char.py
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user