mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-30 20:24:18 +00:00
Merge branch 'k2-fsa:master' into fisher_swbd
This commit is contained in:
commit
49f705c4a0
3
.flake8
3
.flake8
@ -9,7 +9,8 @@ per-file-ignores =
|
||||
egs/*/ASR/pruned_transducer_stateless*/*.py: E501,
|
||||
egs/*/ASR/*/optim.py: E501,
|
||||
egs/*/ASR/*/scaling.py: E501,
|
||||
egs/librispeech/ASR/conv_emformer_transducer_stateless*/*.py: E501, E203,
|
||||
egs/librispeech/ASR/lstm_transducer_stateless/*.py: E501, E203
|
||||
egs/librispeech/ASR/conv_emformer_transducer_stateless*/*.py: E501, E203
|
||||
egs/librispeech/ASR/conformer_ctc2/*py: E501,
|
||||
egs/librispeech/ASR/RESULTS.md: E999,
|
||||
|
||||
|
@ -22,8 +22,80 @@ ls -lh $repo/test_wavs/*.wav
|
||||
|
||||
pushd $repo/exp
|
||||
ln -s pretrained-iter-1224000-avg-14.pt pretrained.pt
|
||||
ln -s pretrained-iter-1224000-avg-14.pt epoch-99.pt
|
||||
popd
|
||||
|
||||
log "Test exporting to ONNX format"
|
||||
|
||||
./pruned_transducer_stateless3/export.py \
|
||||
--exp-dir $repo/exp \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--onnx 1
|
||||
|
||||
log "Export to torchscript model"
|
||||
./pruned_transducer_stateless3/export.py \
|
||||
--exp-dir $repo/exp \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--jit 1
|
||||
|
||||
./pruned_transducer_stateless3/export.py \
|
||||
--exp-dir $repo/exp \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--jit-trace 1
|
||||
|
||||
ls -lh $repo/exp/*.onnx
|
||||
ls -lh $repo/exp/*.pt
|
||||
|
||||
log "Decode with ONNX models"
|
||||
|
||||
./pruned_transducer_stateless3/onnx_check.py \
|
||||
--jit-filename $repo/exp/cpu_jit.pt \
|
||||
--onnx-encoder-filename $repo/exp/encoder.onnx \
|
||||
--onnx-decoder-filename $repo/exp/decoder.onnx \
|
||||
--onnx-joiner-filename $repo/exp/joiner.onnx
|
||||
|
||||
./pruned_transducer_stateless3/onnx_check_all_in_one.py \
|
||||
--jit-filename $repo/exp/cpu_jit.pt \
|
||||
--onnx-all-in-one-filename $repo/exp/all_in_one.onnx
|
||||
|
||||
./pruned_transducer_stateless3/onnx_pretrained.py \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--encoder-model-filename $repo/exp/encoder.onnx \
|
||||
--decoder-model-filename $repo/exp/decoder.onnx \
|
||||
--joiner-model-filename $repo/exp/joiner.onnx \
|
||||
$repo/test_wavs/1089-134686-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0002.wav
|
||||
|
||||
log "Decode with models exported by torch.jit.trace()"
|
||||
|
||||
./pruned_transducer_stateless3/jit_pretrained.py \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--encoder-model-filename $repo/exp/encoder_jit_trace.pt \
|
||||
--decoder-model-filename $repo/exp/decoder_jit_trace.pt \
|
||||
--joiner-model-filename $repo/exp/joiner_jit_trace.pt \
|
||||
$repo/test_wavs/1089-134686-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0002.wav
|
||||
|
||||
log "Decode with models exported by torch.jit.script()"
|
||||
|
||||
./pruned_transducer_stateless3/jit_pretrained.py \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--encoder-model-filename $repo/exp/encoder_jit_script.pt \
|
||||
--decoder-model-filename $repo/exp/decoder_jit_script.pt \
|
||||
--joiner-model-filename $repo/exp/joiner_jit_script.pt \
|
||||
$repo/test_wavs/1089-134686-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0002.wav
|
||||
|
||||
|
||||
for sym in 1 2 3; do
|
||||
log "Greedy search with --max-sym-per-frame $sym"
|
||||
|
||||
|
@ -70,7 +70,7 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" ==
|
||||
max_duration=100
|
||||
|
||||
for method in greedy_search fast_beam_search modified_beam_search; do
|
||||
log "Decoding with $method"
|
||||
log "Simulate streaming decoding with $method"
|
||||
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--decoding-method $method \
|
||||
@ -82,5 +82,19 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" ==
|
||||
--causal-convolution 1
|
||||
done
|
||||
|
||||
for method in greedy_search fast_beam_search modified_beam_search; do
|
||||
log "Real streaming decoding with $method"
|
||||
|
||||
./pruned_transducer_stateless2/streaming_decode.py \
|
||||
--decoding-method $method \
|
||||
--epoch 999 \
|
||||
--avg 1 \
|
||||
--num-decode-streams 100 \
|
||||
--exp-dir pruned_transducer_stateless2/exp \
|
||||
--left-context 32 \
|
||||
--decode-chunk-size 8 \
|
||||
--right-context 0
|
||||
done
|
||||
|
||||
rm pruned_transducer_stateless2/exp/*.pt
|
||||
fi
|
||||
|
65
.github/workflows/build-doc.yml
vendored
Normal file
65
.github/workflows/build-doc.yml
vendored
Normal file
@ -0,0 +1,65 @@
|
||||
# Copyright 2022 Xiaomi Corp. (author: Fangjun Kuang)
|
||||
|
||||
# See ../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# refer to https://github.com/actions/starter-workflows/pull/47/files
|
||||
|
||||
# You can access it at https://k2-fsa.github.io/icefall/
|
||||
name: Generate doc
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
- doc
|
||||
pull_request:
|
||||
types: [labeled]
|
||||
|
||||
jobs:
|
||||
build-doc:
|
||||
if: github.event.label.name == 'doc' || github.event_name == 'push'
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [ubuntu-latest]
|
||||
python-version: ["3.8"]
|
||||
steps:
|
||||
# refer to https://github.com/actions/checkout
|
||||
- uses: actions/checkout@v2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Display Python version
|
||||
run: python -c "import sys; print(sys.version)"
|
||||
|
||||
- name: Build doc
|
||||
shell: bash
|
||||
run: |
|
||||
cd docs
|
||||
python3 -m pip install -r ./requirements.txt
|
||||
make html
|
||||
touch build/html/.nojekyll
|
||||
|
||||
- name: Deploy
|
||||
uses: peaceiris/actions-gh-pages@v3
|
||||
with:
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
publish_dir: ./docs/build/html
|
||||
publish_branch: gh-pages
|
@ -35,7 +35,7 @@ on:
|
||||
|
||||
jobs:
|
||||
run_librispeech_pruned_transducer_stateless3_2022_05_13:
|
||||
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
|
||||
if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
|
2
.github/workflows/style_check.yml
vendored
2
.github/workflows/style_check.yml
vendored
@ -29,7 +29,7 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-18.04, macos-10.15]
|
||||
os: [ubuntu-18.04, macos-latest]
|
||||
python-version: [3.7, 3.9]
|
||||
fail-fast: false
|
||||
|
||||
|
21
README.md
21
README.md
@ -10,6 +10,10 @@ using <https://github.com/k2-fsa/k2>.
|
||||
You can use <https://github.com/k2-fsa/sherpa> to deploy models
|
||||
trained with icefall.
|
||||
|
||||
You can try pre-trained models from within your browser without the need
|
||||
to download or install anything by visiting <https://huggingface.co/spaces/k2-fsa/automatic-speech-recognition>
|
||||
See <https://k2-fsa.github.io/icefall/huggingface/spaces.html> for more details.
|
||||
|
||||
## Installation
|
||||
|
||||
Please refer to <https://icefall.readthedocs.io/en/latest/installation/index.html>
|
||||
@ -246,17 +250,25 @@ We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless mod
|
||||
|
||||
### WenetSpeech
|
||||
|
||||
We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][WenetSpeech_pruned_transducer_stateless2].
|
||||
We provide some models for this recipe: [Pruned stateless RNN-T_2: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][WenetSpeech_pruned_transducer_stateless2] and [Pruned stateless RNN-T_5: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][WenetSpeech_pruned_transducer_stateless5].
|
||||
|
||||
#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with L subset)
|
||||
#### Pruned stateless RNN-T_2: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with L subset, offline ASR)
|
||||
|
||||
| | Dev | Test-Net | Test-Meeting |
|
||||
|----------------------|-------|----------|--------------|
|
||||
| greedy search | 7.80 | 8.75 | 13.49 |
|
||||
| modified beam search| 7.76 | 8.71 | 13.41 |
|
||||
| fast beam search | 7.94 | 8.74 | 13.80 |
|
||||
| modified beam search | 7.76 | 8.71 | 13.41 |
|
||||
|
||||
We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [](https://colab.research.google.com/drive/1EV4e1CHa1GZgEF-bZgizqI9RyFFehIiN?usp=sharing)
|
||||
#### Pruned stateless RNN-T_5: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with L subset)
|
||||
**Streaming**:
|
||||
| | Dev | Test-Net | Test-Meeting |
|
||||
|----------------------|-------|----------|--------------|
|
||||
| greedy_search | 8.78 | 10.12 | 16.16 |
|
||||
| modified_beam_search | 8.53| 9.95 | 15.81 |
|
||||
| fast_beam_search| 9.01 | 10.47 | 16.28 |
|
||||
|
||||
We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless2 model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1EV4e1CHa1GZgEF-bZgizqI9RyFFehIiN?usp=sharing)
|
||||
|
||||
### Alimeeting
|
||||
|
||||
@ -329,6 +341,7 @@ Please see: [ or your custom
|
||||
# ones.
|
||||
extensions = [
|
||||
"sphinx_rtd_theme",
|
||||
"sphinx.ext.todo",
|
||||
"sphinx_rtd_theme",
|
||||
"sphinxcontrib.youtube",
|
||||
]
|
||||
|
||||
# Add any paths that contain templates here, relative to this directory.
|
||||
|
13
docs/source/huggingface/index.rst
Normal file
13
docs/source/huggingface/index.rst
Normal file
@ -0,0 +1,13 @@
|
||||
Huggingface
|
||||
===========
|
||||
|
||||
This section describes how to find pre-trained models.
|
||||
It also demonstrates how to try them from within your browser
|
||||
without installing anything by using
|
||||
`Huggingface spaces <https://huggingface.co/spaces/k2-fsa/automatic-speech-recognition>`_.
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 2
|
||||
|
||||
pretrained-models
|
||||
spaces
|
BIN
docs/source/huggingface/pic/hugging-face-sherpa-2.png
Normal file
BIN
docs/source/huggingface/pic/hugging-face-sherpa-2.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 455 KiB |
BIN
docs/source/huggingface/pic/hugging-face-sherpa-3.png
Normal file
BIN
docs/source/huggingface/pic/hugging-face-sherpa-3.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 392 KiB |
BIN
docs/source/huggingface/pic/hugging-face-sherpa.png
Normal file
BIN
docs/source/huggingface/pic/hugging-face-sherpa.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 426 KiB |
17
docs/source/huggingface/pretrained-models.rst
Normal file
17
docs/source/huggingface/pretrained-models.rst
Normal file
@ -0,0 +1,17 @@
|
||||
Pre-trained models
|
||||
==================
|
||||
|
||||
We have uploaded pre-trained models for all recipes in ``icefall``
|
||||
to `<https://huggingface.co/>`_.
|
||||
|
||||
You can find them by visiting the following link:
|
||||
|
||||
`<https://huggingface.co/models?search=icefall>`_.
|
||||
|
||||
You can also find links to pre-trained models for a specific recipe
|
||||
by looking at the corresponding ``RESULTS.md``. For instance:
|
||||
|
||||
- `<https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS.md>`_
|
||||
- `<https://github.com/k2-fsa/icefall/blob/master/egs/aishell/ASR/RESULTS.md>`_
|
||||
- `<https://github.com/k2-fsa/icefall/blob/master/egs/gigaspeech/ASR/RESULTS.md>`_
|
||||
- `<https://github.com/k2-fsa/icefall/blob/master/egs/wenetspeech/ASR/RESULTS.md>`_
|
65
docs/source/huggingface/spaces.rst
Normal file
65
docs/source/huggingface/spaces.rst
Normal file
@ -0,0 +1,65 @@
|
||||
Huggingface spaces
|
||||
==================
|
||||
|
||||
We have integrated the server framework
|
||||
`sherpa <http://github.com/k2-fsa/sherpa>`_
|
||||
with `Huggingface spaces <https://huggingface.co/spaces/k2-fsa/automatic-speech-recognition>`_
|
||||
so that you can try pre-trained models from within your browser
|
||||
without the need to download or install anything.
|
||||
|
||||
All you need is a browser, which can be run on Windows, macOS, Linux, or even on your
|
||||
iPad and your phone.
|
||||
|
||||
Start your browser and visit the following address:
|
||||
|
||||
`<https://huggingface.co/spaces/k2-fsa/automatic-speech-recognition>`_
|
||||
|
||||
and you will see a page like the following screenshot:
|
||||
|
||||
.. image:: ./pic/hugging-face-sherpa.png
|
||||
:alt: screenshot of `<https://huggingface.co/spaces/k2-fsa/automatic-speech-recognition>`_
|
||||
:target: https://huggingface.co/spaces/k2-fsa/automatic-speech-recognition
|
||||
|
||||
You can:
|
||||
|
||||
1. Select a language for recognition. Currently, we provide pre-trained models
|
||||
from ``icefall`` for the following languages: ``Chinese``, ``English``, and
|
||||
``Chinese+English``.
|
||||
2. After selecting the target language, you can select a pre-trained model
|
||||
corresponding to the language.
|
||||
3. Select the decoding method. Currently, it provides ``greedy search``
|
||||
and ``modified_beam_search``.
|
||||
4. If you selected ``modified_beam_search``, you can choose the number of
|
||||
active paths during the search.
|
||||
5. Either upload a file or record your speech for recognition.
|
||||
6. Click the button ``Submit for recognition``.
|
||||
7. Wait for a moment and you will get the recognition results.
|
||||
|
||||
The following screenshot shows an example when selecting ``Chinese+English``:
|
||||
|
||||
.. image:: ./pic/hugging-face-sherpa-3.png
|
||||
:alt: screenshot of `<https://huggingface.co/spaces/k2-fsa/automatic-speech-recognition>`_
|
||||
:target: https://huggingface.co/spaces/k2-fsa/automatic-speech-recognition
|
||||
|
||||
|
||||
In the bottom part of the page, you can find a table of examples. You can click
|
||||
one of them and then click ``Submit for recognition``.
|
||||
|
||||
.. image:: ./pic/hugging-face-sherpa-2.png
|
||||
:alt: screenshot of `<https://huggingface.co/spaces/k2-fsa/automatic-speech-recognition>`_
|
||||
:target: https://huggingface.co/spaces/k2-fsa/automatic-speech-recognition
|
||||
|
||||
YouTube Video
|
||||
-------------
|
||||
|
||||
We provide the following YouTube video demonstrating how to use
|
||||
`<https://huggingface.co/spaces/k2-fsa/automatic-speech-recognition>`_.
|
||||
|
||||
.. note::
|
||||
|
||||
To get the latest news of `next-gen Kaldi <https://github.com/k2-fsa>`_, please subscribe
|
||||
to the following YouTube channel by `Nadira Povey <https://www.youtube.com/channel/UC_VaumpkmINz1pNkFXAN9mw>`_:
|
||||
|
||||
`<https://www.youtube.com/channel/UC_VaumpkmINz1pNkFXAN9mw>`_
|
||||
|
||||
.. youtube:: ElN3r9dkKE4
|
@ -23,3 +23,4 @@ speech recognition recipes using `k2 <https://github.com/k2-fsa/k2>`_.
|
||||
installation/index
|
||||
recipes/index
|
||||
contributing/index
|
||||
huggingface/index
|
||||
|
@ -474,3 +474,19 @@ The decoding log is:
|
||||
**Congratulations!** You have successfully set up the environment and have run the first recipe in ``icefall``.
|
||||
|
||||
Have fun with ``icefall``!
|
||||
|
||||
YouTube Video
|
||||
-------------
|
||||
|
||||
We provide the following YouTube video showing how to install ``icefall``.
|
||||
It also shows how to debug various problems that you may encounter while
|
||||
using ``icefall``.
|
||||
|
||||
.. note::
|
||||
|
||||
To get the latest news of `next-gen Kaldi <https://github.com/k2-fsa>`_, please subscribe
|
||||
to the following YouTube channel by `Nadira Povey <https://www.youtube.com/channel/UC_VaumpkmINz1pNkFXAN9mw>`_:
|
||||
|
||||
`<https://www.youtube.com/channel/UC_VaumpkmINz1pNkFXAN9mw>`_
|
||||
|
||||
.. youtube:: LVmrBD0tLfE
|
||||
|
@ -70,6 +70,17 @@ To run stage 2 to stage 5, use:
|
||||
All files generated by ``./prepare.sh``, e.g., features, lexicon, etc.,
|
||||
are saved in ``./data`` directory.
|
||||
|
||||
We provide the following YouTube video showing how to run ``./prepare.sh``.
|
||||
|
||||
.. note::
|
||||
|
||||
To get the latest news of `next-gen Kaldi <https://github.com/k2-fsa>`_, please subscribe
|
||||
to the following YouTube channel by `Nadira Povey <https://www.youtube.com/channel/UC_VaumpkmINz1pNkFXAN9mw>`_:
|
||||
|
||||
`<https://www.youtube.com/channel/UC_VaumpkmINz1pNkFXAN9mw>`_
|
||||
|
||||
.. youtube:: ofEIoJL-mGM
|
||||
|
||||
|
||||
Training
|
||||
--------
|
||||
|
@ -45,6 +45,16 @@ To run stage 2 to stage 5, use:
|
||||
|
||||
$ ./prepare.sh --stage 2 --stop-stage 5
|
||||
|
||||
We provide the following YouTube video showing how to run ``./prepare.sh``.
|
||||
|
||||
.. note::
|
||||
|
||||
To get the latest news of `next-gen Kaldi <https://github.com/k2-fsa>`_, please subscribe
|
||||
to the following YouTube channel by `Nadira Povey <https://www.youtube.com/channel/UC_VaumpkmINz1pNkFXAN9mw>`_:
|
||||
|
||||
`<https://www.youtube.com/channel/UC_VaumpkmINz1pNkFXAN9mw>`_
|
||||
|
||||
.. youtube:: ofEIoJL-mGM
|
||||
|
||||
Training
|
||||
--------
|
||||
|
@ -43,7 +43,7 @@ torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
|
||||
src_dir = Path("data/manifests")
|
||||
src_dir = Path("data/manifests/aidatatang_200zh")
|
||||
output_dir = Path("data/fbank")
|
||||
num_jobs = min(15, os.cpu_count())
|
||||
|
||||
|
@ -50,28 +50,19 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||
fi
|
||||
|
||||
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
log "Stage 2: Process aidatatang_200zh"
|
||||
if [ ! -f data/fbank/aidatatang_200zh/.fbank.done ]; then
|
||||
mkdir -p data/fbank/aidatatang_200zh
|
||||
lhotse prepare aidatatang-200zh $dl_dir data/manifests/aidatatang_200zh
|
||||
touch data/fbank/aidatatang_200zh/.fbank.done
|
||||
log "Stage 2: Prepare musan manifest"
|
||||
# We assume that you have downloaded the musan corpus
|
||||
# to data/musan
|
||||
if [ ! -f data/manifests/.manifests.done ]; then
|
||||
log "It may take 6 minutes"
|
||||
mkdir -p data/manifests/
|
||||
lhotse prepare musan $dl_dir/musan data/manifests/
|
||||
touch data/manifests/.manifests.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||
log "Stage 3: Prepare musan manifest"
|
||||
# We assume that you have downloaded the musan corpus
|
||||
# to data/musan
|
||||
if [ ! -f data/manifests/.musan_manifests.done ]; then
|
||||
log "It may take 6 minutes"
|
||||
mkdir -p data/manifests
|
||||
lhotse prepare musan $dl_dir/musan data/manifests
|
||||
touch data/manifests/.musan_manifests.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
log "Stage 4: Compute fbank for musan"
|
||||
log "Stage 3: Compute fbank for musan"
|
||||
if [ ! -f data/fbank/.msuan.done ]; then
|
||||
mkdir -p data/fbank
|
||||
./local/compute_fbank_musan.py
|
||||
@ -79,8 +70,8 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
log "Stage 5: Compute fbank for aidatatang_200zh"
|
||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
log "Stage 4: Compute fbank for aidatatang_200zh"
|
||||
if [ ! -f data/fbank/.aidatatang_200zh.done ]; then
|
||||
mkdir -p data/fbank
|
||||
./local/compute_fbank_aidatatang_200zh.py
|
||||
@ -88,31 +79,38 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
log "Stage 6: Prepare char based lang"
|
||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
log "Stage 5: Prepare char based lang"
|
||||
lang_char_dir=data/lang_char
|
||||
mkdir -p $lang_char_dir
|
||||
|
||||
# Prepare text.
|
||||
grep "\"text\":" data/manifests/aidatatang_200zh/supervisions_train.json \
|
||||
| sed -e 's/["text:\t ]*//g' | sed 's/,//g' \
|
||||
| ./local/text2token.py -t "char" > $lang_char_dir/text
|
||||
|
||||
# Note: in Linux, you can install jq with the following command:
|
||||
# 1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64
|
||||
# 2. chmod +x ./jq
|
||||
# 3. cp jq /usr/bin
|
||||
if [ ! -f $lang_char_dir/text ]; then
|
||||
gunzip -c data/manifests/aidatatang_200zh/aidatatang_supervisions_train.jsonl.gz \
|
||||
|jq '.text' |sed -e 's/["text:\t ]*//g' | sed 's/"//g' \
|
||||
| ./local/text2token.py -t "char" > $lang_char_dir/text
|
||||
fi
|
||||
# Prepare words.txt
|
||||
grep "\"text\":" data/manifests/aidatatang_200zh/supervisions_train.json \
|
||||
| sed -e 's/["text:\t]*//g' | sed 's/,//g' \
|
||||
| ./local/text2token.py -t "char" > $lang_char_dir/text_words
|
||||
if [ ! -f $lang_char_dir/text_words ]; then
|
||||
gunzip -c data/manifests/aidatatang_200zh/aidatatang_supervisions_train.jsonl.gz \
|
||||
| jq '.text' | sed -e 's/["text:\t]*//g' | sed 's/"//g' \
|
||||
| ./local/text2token.py -t "char" > $lang_char_dir/text_words
|
||||
fi
|
||||
|
||||
cat $lang_char_dir/text_words | sed 's/ /\n/g' | sort -u | sed '/^$/d' \
|
||||
| uniq > $lang_char_dir/words_no_ids.txt
|
||||
|
||||
if [ ! -f $lang_char_dir/words.txt ]; then
|
||||
./local/prepare_words.py \
|
||||
--input-file $lang_char_dir/words_no_ids.txt
|
||||
--output-file $lang_char_dir/words.txt
|
||||
--input-file $lang_char_dir/words_no_ids.txt \
|
||||
--output-file $lang_char_dir/words.txt
|
||||
fi
|
||||
|
||||
if [ ! -f $lang_char_dir/L_disambig.pt ]; then
|
||||
./local/prepare_char.py
|
||||
fi
|
||||
fi
|
||||
|
||||
|
@ -367,6 +367,7 @@ def decode_dataset(
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
texts = [list(str(text).replace(" ", "")) for text in texts]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
@ -379,8 +380,8 @@ def decode_dataset(
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for hyp_words, ref_text in zip(hyps, texts):
|
||||
this_batch.append((ref_text, hyp_words))
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
this_batch.append((cut_id, ref_text, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
@ -405,6 +406,7 @@ def save_results(
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
@ -520,61 +522,14 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# Note: Please use "pip install webdataset==0.1.103"
|
||||
# for installing the webdataset.
|
||||
import glob
|
||||
import os
|
||||
|
||||
from lhotse import CutSet
|
||||
from lhotse.dataset.webdataset import export_to_webdataset
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
aidatatang_200zh = Aidatatang_200zhAsrDataModule(args)
|
||||
|
||||
dev = "dev"
|
||||
test = "test"
|
||||
|
||||
if not os.path.exists(f"{dev}/shared-0.tar"):
|
||||
os.makedirs(dev)
|
||||
dev_cuts = aidatatang_200zh.valid_cuts()
|
||||
export_to_webdataset(
|
||||
dev_cuts,
|
||||
output_path=f"{dev}/shared-%d.tar",
|
||||
shard_size=300,
|
||||
)
|
||||
|
||||
if not os.path.exists(f"{test}/shared-0.tar"):
|
||||
os.makedirs(test)
|
||||
test_cuts = aidatatang_200zh.test_cuts()
|
||||
export_to_webdataset(
|
||||
test_cuts,
|
||||
output_path=f"{test}/shared-%d.tar",
|
||||
shard_size=300,
|
||||
)
|
||||
|
||||
dev_shards = [
|
||||
str(path)
|
||||
for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
|
||||
]
|
||||
cuts_dev_webdataset = CutSet.from_webdataset(
|
||||
dev_shards,
|
||||
split_by_worker=True,
|
||||
split_by_node=True,
|
||||
shuffle_shards=True,
|
||||
)
|
||||
|
||||
test_shards = [
|
||||
str(path)
|
||||
for path in sorted(glob.glob(os.path.join(test, "shared-*.tar")))
|
||||
]
|
||||
cuts_test_webdataset = CutSet.from_webdataset(
|
||||
test_shards,
|
||||
split_by_worker=True,
|
||||
split_by_node=True,
|
||||
shuffle_shards=True,
|
||||
)
|
||||
|
||||
dev_dl = aidatatang_200zh.valid_dataloaders(cuts_dev_webdataset)
|
||||
test_dl = aidatatang_200zh.test_dataloaders(cuts_test_webdataset)
|
||||
dev_cuts = aidatatang_200zh.valid_cuts()
|
||||
test_cuts = aidatatang_200zh.test_cuts()
|
||||
dev_dl = aidatatang_200zh.valid_dataloaders(dev_cuts)
|
||||
test_dl = aidatatang_200zh.test_dataloaders(test_cuts)
|
||||
|
||||
test_sets = ["dev", "test"]
|
||||
test_dl = [dev_dl, test_dl]
|
||||
|
@ -81,6 +81,58 @@ We have a tutorial in [sherpa](https://github.com/k2-fsa/sherpa) about how
|
||||
to use the pre-trained model for non-streaming ASR. See
|
||||
<https://k2-fsa.github.io/sherpa/offline_asr/conformer/aishell.html>
|
||||
|
||||
|
||||
#### Pruned transducer stateless 2
|
||||
|
||||
See https://github.com/k2-fsa/icefall/pull/536
|
||||
|
||||
[./pruned_transducer_stateless2](./pruned_transducer_stateless2)
|
||||
|
||||
It uses pruned RNN-T.
|
||||
|
||||
| | test | dev | comment |
|
||||
| -------------------- | ---- | ---- | -------------------------------------- |
|
||||
| greedy search | 5.20 | 4.78 | --epoch 72 --avg 14 --max-duration 200 |
|
||||
| modified beam search | 5.07 | 4.63 | --epoch 72 --avg 14 --max-duration 200 |
|
||||
| fast beam search | 5.13 | 4.70 | --epoch 72 --avg 14 --max-duration 200 |
|
||||
|
||||
Training command is:
|
||||
|
||||
```bash
|
||||
./prepare.sh
|
||||
|
||||
export CUDA_VISIBLE_DEVICES="0,1"
|
||||
|
||||
./pruned_transducer_stateless2/train.py \
|
||||
--world-size 2 \
|
||||
--num-epochs 90 \
|
||||
--start-epoch 0 \
|
||||
--exp-dir pruned_transducer_stateless2/exp \
|
||||
--max-duration 200 \
|
||||
```
|
||||
|
||||
The tensorboard log is available at
|
||||
https://tensorboard.dev/experiment/QI3PVzrGRrebxpbWUPwmkA/
|
||||
|
||||
The decoding command is:
|
||||
```bash
|
||||
for m in greedy_search modified_beam_search fast_beam_search ; do
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--epoch 72 \
|
||||
--avg 14 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--lang-dir data/lang_char \
|
||||
--max-duration 200 \
|
||||
--decoding-method $m
|
||||
|
||||
done
|
||||
```
|
||||
|
||||
Pretrained models, training logs, decoding logs, and decoding results
|
||||
are available at
|
||||
<https://huggingface.co/teapoly/icefall-aishell-pruned-transducer-stateless2-2022-08-18>
|
||||
|
||||
|
||||
#### 2022-03-01
|
||||
|
||||
[./transducer_stateless_modified-2](./transducer_stateless_modified-2)
|
||||
|
@ -374,6 +374,7 @@ def decode_dataset(
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
@ -389,9 +390,9 @@ def decode_dataset(
|
||||
for lm_scale, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for hyp_words, ref_text in zip(hyps, texts):
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((ref_words, hyp_words))
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[lm_scale].extend(this_batch)
|
||||
|
||||
@ -419,6 +420,7 @@ def save_results(
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
if enable_log:
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
@ -429,7 +431,9 @@ def save_results(
|
||||
# we compute CER for aishell dataset.
|
||||
results_char = []
|
||||
for res in results:
|
||||
results_char.append((list("".join(res[0])), list("".join(res[1]))))
|
||||
results_char.append(
|
||||
(res[0], list("".join(res[1])), list("".join(res[2])))
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
|
||||
@ -537,6 +541,8 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
aishell = AishellAsrDataModule(args)
|
||||
test_cuts = aishell.test_cuts()
|
||||
test_dl = aishell.test_dataloaders(test_cuts)
|
||||
|
@ -386,6 +386,7 @@ def decode_dataset(
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
@ -401,9 +402,9 @@ def decode_dataset(
|
||||
for lm_scale, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for hyp_words, ref_text in zip(hyps, texts):
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((ref_words, hyp_words))
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[lm_scale].extend(this_batch)
|
||||
|
||||
@ -431,6 +432,7 @@ def save_results(
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
if enable_log:
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
@ -441,7 +443,9 @@ def save_results(
|
||||
# we compute CER for aishell dataset.
|
||||
results_char = []
|
||||
for res in results:
|
||||
results_char.append((list("".join(res[0])), list("".join(res[1]))))
|
||||
results_char.append(
|
||||
(res[0], list("".join(res[1])), list("".join(res[2])))
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
|
||||
@ -556,6 +560,8 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
aishell = AishellAsrDataModule(args)
|
||||
test_cuts = aishell.test_cuts()
|
||||
test_dl = aishell.test_dataloaders(test_cuts)
|
||||
|
@ -48,6 +48,8 @@ if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
|
||||
log "stage -1: Download LM"
|
||||
# We assume that you have installed the git-lfs, if not, you could install it
|
||||
# using: `sudo apt-get install git-lfs && git-lfs install`
|
||||
git lfs 1>/dev/null 2>&1 || (echo "please install git-lfs, consider using: sudo apt-get install git-lfs && git-lfs install" && exit 1)
|
||||
|
||||
if [ ! -f $dl_dir/lm/3-gram.unpruned.arpa ]; then
|
||||
git clone https://huggingface.co/pkufool/aishell_lm $dl_dir/lm
|
||||
fi
|
||||
|
1
egs/aishell/ASR/pruned_transducer_stateless2/asr_datamodule.py
Symbolic link
1
egs/aishell/ASR/pruned_transducer_stateless2/asr_datamodule.py
Symbolic link
@ -0,0 +1 @@
|
||||
../tdnn_lstm_ctc/asr_datamodule.py
|
1
egs/aishell/ASR/pruned_transducer_stateless2/beam_search.py
Symbolic link
1
egs/aishell/ASR/pruned_transducer_stateless2/beam_search.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py
|
1
egs/aishell/ASR/pruned_transducer_stateless2/conformer.py
Symbolic link
1
egs/aishell/ASR/pruned_transducer_stateless2/conformer.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless2/conformer.py
|
573
egs/aishell/ASR/pruned_transducer_stateless2/decode.py
Executable file
573
egs/aishell/ASR/pruned_transducer_stateless2/decode.py
Executable file
@ -0,0 +1,573 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||
# Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
(1) greedy search
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--epoch 84 \
|
||||
--avg 25 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search
|
||||
|
||||
(2) beam search (not recommended)
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--epoch 84 \
|
||||
--avg 25 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(3) modified beam search
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--epoch 84 \
|
||||
--avg 25 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(4) fast beam search
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--epoch 84 \
|
||||
--avg 25 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search \
|
||||
--beam 4 \
|
||||
--max-contexts 4 \
|
||||
--max-states 8
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import AishellAsrDataModule
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
|
||||
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the decoding script.

    Returns:
      A parser holding the checkpoint-selection and decoding options
      declared below, plus whatever :func:`add_model_arguments` registers
      on it.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # Checkpoint selection: --epoch/--iter pick the last checkpoint,
    # --avg says how many consecutive checkpoints are averaged.
    parser.add_argument(
        "--epoch",
        type=int,
        default=30,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="pruned_transducer_stateless2/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--lang-dir",
        type=str,
        default="data/lang_char",
        help="The lang dir",
    )

    parser.add_argument(
        "--decoding-method",
        type=str,
        default="greedy_search",
        help="""Possible values are:
          - greedy_search
          - beam_search
          - modified_beam_search
          - fast_beam_search
        """,
    )

    # Note the two different "beam" options: --beam-size is a candidate
    # count, --beam is a score margin.
    parser.add_argument(
        "--beam-size",
        type=int,
        default=4,
        help="""An integer indicating how many candidates we will keep for each
        frame. Used only when --decoding-method is beam_search or
        modified_beam_search.""",
    )

    parser.add_argument(
        "--beam",
        type=float,
        default=4,
        help="""A floating point value to calculate the cutoff score during beam
        search (i.e., `cutoff = max-score - beam`), which is the same as the
        `beam` in Kaldi.
        Used only when --decoding-method is fast_beam_search""",
    )

    parser.add_argument(
        "--max-contexts",
        type=int,
        default=4,
        help="""Used only when --decoding-method is
        fast_beam_search""",
    )

    parser.add_argument(
        "--max-states",
        type=int,
        default=8,
        help="""Used only when --decoding-method is
        fast_beam_search""",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=1,
        help="The context size in the decoder. 1 means bigram; "
        "2 means tri-gram",
    )
    parser.add_argument(
        "--max-sym-per-frame",
        type=int,
        default=1,
        help="""Maximum number of symbols per frame.
        Used only when --decoding_method is greedy_search""",
    )

    # Model-architecture options are declared by the training script.
    add_model_arguments(parser)

    return parser
|
||||
|
||||
|
||||
def decode_one_batch(
    params: AttributeDict,
    model: nn.Module,
    token_table: k2.SymbolTable,
    batch: dict,
    decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
    """Run the encoder and the selected search on a single batch.

    The result is a dict with exactly one entry:

        - key: a name describing the decoding setting. It is
               "greedy_search" for greedy search,
               "beam_<beam>_max_contexts_<n>_max_states_<m>" for
               fast_beam_search, and "beam_size_<beam_size>" for
               beam_search / modified_beam_search.
        - value: one list of token strings per utterance, in batch order.

    Args:
      params:
        It's the return value of :func:`get_params`.
      model:
        The neural model.
      token_table:
        It maps token ID to a string.
      batch:
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
        for the format of the `batch`.
      decoding_graph:
        The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
        only when --decoding_method is fast_beam_search.
    Returns:
      Return the decoding result. See above description for the format of
      the returned dict.
    Raises:
      ValueError: if params.decoding_method is not one of the supported
        methods (raised while decoding the first utterance).
    """
    device = next(model.parameters()).device

    features = batch["inputs"]
    assert features.ndim == 3
    # at entry, features is (N, T, C)
    features = features.to(device)

    num_frames = batch["supervisions"]["num_frames"].to(device)

    encoder_out, encoder_out_lens = model.encoder(
        x=features, x_lens=num_frames
    )

    method = params.decoding_method

    def _decode_single(utt_encoder_out):
        # Per-utterance fallback used by the non-batched search methods.
        if method == "greedy_search":
            return greedy_search(
                model=model,
                encoder_out=utt_encoder_out,
                max_sym_per_frame=params.max_sym_per_frame,
            )
        if method == "beam_search":
            return beam_search(
                model=model,
                encoder_out=utt_encoder_out,
                beam=params.beam_size,
            )
        raise ValueError(
            f"Unsupported decoding method: {params.decoding_method}"
        )

    if method == "fast_beam_search":
        hyp_tokens = fast_beam_search_one_best(
            model=model,
            decoding_graph=decoding_graph,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
        )
    elif method == "greedy_search" and params.max_sym_per_frame == 1:
        # Batched greedy search only handles one symbol per frame.
        hyp_tokens = greedy_search_batch(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
        )
    elif method == "modified_beam_search":
        hyp_tokens = modified_beam_search(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam_size,
        )
    else:
        # Decode utterance by utterance, trimming the padding frames.
        hyp_tokens = []
        for i in range(encoder_out.size(0)):
            # fmt: off
            utt_encoder_out = encoder_out[i:i+1, :encoder_out_lens[i]]
            # fmt: on
            hyp_tokens.append(_decode_single(utt_encoder_out))

    hyps = [[token_table[t] for t in tokens] for tokens in hyp_tokens]

    if method == "greedy_search":
        result_key = "greedy_search"
    elif method == "fast_beam_search":
        result_key = (
            f"beam_{params.beam}_"
            f"max_contexts_{params.max_contexts}_"
            f"max_states_{params.max_states}"
        )
    else:
        result_key = f"beam_size_{params.beam_size}"

    return {result_key: hyps}
|
||||
|
||||
|
||||
def decode_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
    token_table: k2.SymbolTable,
    decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
    """Decode dataset.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to decode. Each batch
        must carry the cuts in `batch["supervisions"]["cut"]`, since the
        cut ids are stored alongside the results.
      params:
        It is returned by :func:`get_params`.
      model:
        The neural model.
      token_table:
        It maps a token ID to a string.
      decoding_graph:
        The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
        only when --decoding_method is fast_beam_search.
    Returns:
      Return a dict whose keys are the setting names produced by
      :func:`decode_one_batch` (e.g. "greedy_search").
      Its value is a list of tuples. Each tuple contains three elements:
      the cut id, the reference transcript (split on whitespace), and the
      predicted result.
    """
    num_cuts = 0

    # Some dataloaders (e.g. with dynamic samplers) have no length.
    try:
        num_batches = len(dl)
    except TypeError:
        num_batches = "?"

    log_interval = 50 if params.decoding_method == "greedy_search" else 20

    results = defaultdict(list)
    for batch_idx, batch in enumerate(dl):
        texts = batch["supervisions"]["text"]
        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

        hyps_dict = decode_one_batch(
            params=params,
            model=model,
            token_table=token_table,
            decoding_graph=decoding_graph,
            batch=batch,
        )

        for name, hyps in hyps_dict.items():
            # zip() would silently truncate on a length mismatch, so check
            # all three sequences up front.
            assert len(hyps) == len(texts) == len(cut_ids)
            results[name].extend(
                (cut_id, ref_text.split(), hyp_words)
                for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts)
            )

        num_cuts += len(texts)

        if batch_idx % log_interval == 0:
            batch_str = f"{batch_idx}/{num_batches}"

            logging.info(
                f"batch {batch_str}, cuts processed until now is {num_cuts}"
            )
    return results
|
||||
|
||||
|
||||
def save_results(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
    """Write transcripts, error statistics and an error-rate summary.

    For every decoding setting in `results_dict` this stores the sorted
    transcripts and the detailed error statistics under `params.res_dir`,
    then writes one summary file ranking all settings and logs the ranking.
    The error rate is computed on characters (CER), not on words.

    Args:
      params:
        It is returned by :func:`get_params`. `res_dir` and `suffix` are
        read from it to build the output file names.
      test_set_name:
        Name of the decoded set; it is embedded in every file name.
      results_dict:
        Maps a decoding setting to a list of
        (cut_id, reference words, hypothesis words) tuples, as returned by
        :func:`decode_dataset`.
    """
    if not results_dict:
        # Nothing was decoded (e.g. an empty dataloader). There is no setting
        # name to build the summary file name from, so bail out instead of
        # failing on an unbound loop variable.
        logging.warning(
            f"No decoding results for {test_set_name}; nothing to save"
        )
        return

    test_set_wers = dict()
    for key, results in results_dict.items():
        recog_path = (
            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        results = sorted(results)
        store_transcripts(filename=recog_path, texts=results)
        logging.info(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = (
            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        # we compute CER for aishell dataset: join the words and split the
        # result into single characters, keeping the cut id untouched.
        results_char = [
            (cut_id, list("".join(ref)), list("".join(hyp)))
            for cut_id, ref, hyp in results
        ]
        with open(errs_filename, "w") as f:
            wer = write_error_stats(
                f, f"{test_set_name}-{key}", results_char, enable_log=True
            )
            test_set_wers[key] = wer

        logging.info("Wrote detailed error stats to {}".format(errs_filename))

    # The summary file has always been named after the last setting that was
    # processed; keep that name, but pick the setting explicitly.
    summary_key = list(results_dict)[-1]

    ranked_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
    errs_info = (
        params.res_dir
        / f"wer-summary-{test_set_name}-{summary_key}-{params.suffix}.txt"
    )
    with open(errs_info, "w") as f:
        print("settings\tWER", file=f)
        for setting, val in ranked_wers:
            print("{}\t{}".format(setting, val), file=f)

    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
    # Only the first (lowest error rate) line gets the "best" marker.
    note = "\tbest for {}".format(test_set_name)
    for setting, val in ranked_wers:
        s += "{}\t{}{}\n".format(setting, val, note)
        note = ""
    logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
def main():
    """Decode the aishell test and dev sets with a trained transducer.

    Parses the command line, loads a single checkpoint or an average of
    several, decodes both sets with the requested method and writes the
    transcripts and error statistics under
    `<exp_dir>/<decoding_method>/`.

    Raises:
      ValueError: if --iter is used and fewer than --avg checkpoints are
        found.
    """
    parser = get_parser()
    AishellAsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)
    args.lang_dir = Path(args.lang_dir)

    params = get_params()
    params.update(vars(args))

    assert params.decoding_method in (
        "greedy_search",
        "beam_search",
        "fast_beam_search",
        "modified_beam_search",
    )
    params.res_dir = params.exp_dir / params.decoding_method

    # The suffix encodes the checkpoint and search settings so that results
    # of different runs do not overwrite each other.
    if params.iter > 0:
        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
    else:
        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

    # "fast_beam_search" must be tested first: it also contains the
    # substring "beam_search".
    if "fast_beam_search" in params.decoding_method:
        params.suffix += f"-beam-{params.beam}"
        params.suffix += f"-max-contexts-{params.max_contexts}"
        params.suffix += f"-max-states-{params.max_states}"
    elif "beam_search" in params.decoding_method:
        params.suffix += (
            f"-{params.decoding_method}-beam-size-{params.beam_size}"
        )
    else:
        params.suffix += f"-context-{params.context_size}"
        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
    logging.info("Decoding started")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"Device: {device}")

    lexicon = Lexicon(params.lang_dir)
    params.blank_id = 0
    params.vocab_size = max(lexicon.tokens) + 1

    logging.info(params)

    logging.info("About to create model")
    model = get_transducer_model(params)

    if params.iter > 0:
        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
            : params.avg
        ]
        if len(filenames) == 0:
            raise ValueError(
                f"No checkpoints found for"
                f" --iter {params.iter}, --avg {params.avg}"
            )
        elif len(filenames) < params.avg:
            raise ValueError(
                f"Not enough checkpoints ({len(filenames)}) found for"
                f" --iter {params.iter}, --avg {params.avg}"
            )
        logging.info(f"averaging {filenames}")
        model.to(device)
        model.load_state_dict(
            average_checkpoints(filenames, device=device), strict=False
        )
    elif params.avg == 1:
        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    else:
        # Average epochs [epoch - avg + 1, epoch]; epochs count from 1, so
        # anything below 1 is dropped.
        start = params.epoch - params.avg + 1
        filenames = []
        for i in range(start, params.epoch + 1):
            if i >= 1:
                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
        logging.info(f"averaging {filenames}")
        model.to(device)
        model.load_state_dict(
            average_checkpoints(filenames, device=device), strict=False
        )

    model.to(device)
    model.eval()

    if params.decoding_method == "fast_beam_search":
        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
    else:
        decoding_graph = None

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    # we need cut ids to display recognition results.
    # decode_dataset() reads batch["supervisions"]["cut"], so force the
    # option on regardless of what was given on the command line.
    args.return_cuts = True
    aishell = AishellAsrDataModule(args)
    test_cuts = aishell.test_cuts()
    dev_cuts = aishell.valid_cuts()
    test_dl = aishell.test_dataloaders(test_cuts)
    dev_dl = aishell.test_dataloaders(dev_cuts)

    test_sets = ["test", "dev"]
    test_dls = [test_dl, dev_dl]

    for test_set, test_dl in zip(test_sets, test_dls):
        results_dict = decode_dataset(
            dl=test_dl,
            params=params,
            model=model,
            token_table=lexicon.token_table,
            decoding_graph=decoding_graph,
        )

        save_results(
            params=params,
            test_set_name=test_set,
            results_dict=results_dict,
        )

    logging.info("Done!")
|
||||
|
||||
|
||||
# Run the decoding pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
1
egs/aishell/ASR/pruned_transducer_stateless2/decoder.py
Symbolic link
1
egs/aishell/ASR/pruned_transducer_stateless2/decoder.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless2/decoder.py
|
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py
|
217
egs/aishell/ASR/pruned_transducer_stateless2/export.py
Executable file
217
egs/aishell/ASR/pruned_transducer_stateless2/export.py
Executable file
@ -0,0 +1,217 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This script converts several saved checkpoints
|
||||
# to a single one using model averaging.
|
||||
"""
|
||||
Usage:
|
||||
./pruned_transducer_stateless2/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--jit 0 \
|
||||
--epoch 29 \
|
||||
--avg 5
|
||||
|
||||
It will generate a file exp_dir/pretrained-epoch-29-avg-5.pt
|
||||
|
||||
To use the generated file with `pruned_transducer_stateless2/decode.py`,
|
||||
you can do::
|
||||
|
||||
cd /path/to/exp_dir
|
||||
ln -s pretrained-epoch-29-avg-5.pt epoch-9999.pt
|
||||
|
||||
cd /path/to/egs/aishell/ASR
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--epoch 9999 \
|
||||
--avg 1 \
|
||||
--max-duration 100 \
|
||||
--lang-dir data/lang_char
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the export script.

    Returns:
      A parser holding the checkpoint-selection and export options declared
      below, plus whatever :func:`add_model_arguments` registers on it.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # Checkpoint selection: --epoch/--iter pick the last checkpoint,
    # --avg says how many consecutive checkpoints are averaged.
    parser.add_argument(
        "--epoch",
        type=int,
        default=29,
        help="""It specifies the checkpoint to use for averaging.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--exp-dir",
        type=Path,
        default=Path("pruned_transducer_stateless2/exp"),
        help="""It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
        """,
    )

    # Selects the output format: TorchScript when true, a plain state dict
    # otherwise.
    parser.add_argument(
        "--jit",
        type=str2bool,
        default=False,
        help="""True to save a model after applying torch.jit.script.
        """,
    )

    parser.add_argument(
        "--lang-dir",
        type=Path,
        default=Path("data/lang_char"),
        help="The lang dir",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=1,
        help="The context size in the decoder. 1 means bigram; "
        "2 means tri-gram",
    )

    # Model-architecture options are declared by the training script.
    add_model_arguments(parser)

    return parser
|
||||
|
||||
|
||||
def main():
    """Average the selected checkpoints and export the resulting model.

    With --jit the model is compiled by torch.jit.script and saved as
    `cpu_jit-epoch-<epoch>-avg-<avg>.pt`; otherwise its state dict is saved
    as `pretrained-epoch-<epoch>-avg-<avg>.pt` in a format that
    :func:`load_checkpoint` can read. Both files go to --exp-dir.

    Raises:
      ValueError: if --iter is used and fewer than --avg checkpoints are
        found.
    """
    args = get_parser().parse_args()

    params = get_params()
    params.update(vars(args))

    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    else:
        device = torch.device("cpu")

    logging.info(f"device: {device}")

    lexicon = Lexicon(params.lang_dir)

    params.blank_id = 0
    params.vocab_size = max(lexicon.tokens) + 1

    logging.info(params)

    logging.info("About to create model")
    model = get_transducer_model(params)

    by_iteration = params.iter > 0
    if not by_iteration and params.avg == 1:
        # A single epoch checkpoint: nothing to average.
        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    else:
        if by_iteration:
            checkpoints = find_checkpoints(
                params.exp_dir, iteration=-params.iter
            )[: params.avg]
            if not checkpoints:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            if len(checkpoints) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(checkpoints)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
        else:
            # Epochs count from 1, so anything below 1 is dropped.
            first_epoch = params.epoch - params.avg + 1
            checkpoints = [
                f"{params.exp_dir}/epoch-{i}.pt"
                for i in range(first_epoch, params.epoch + 1)
                if i >= 1
            ]
        logging.info(f"averaging {checkpoints}")
        model.to(device)
        model.load_state_dict(average_checkpoints(checkpoints, device=device))

    # The exported model always lives on the CPU.
    model.to("cpu")
    model.eval()

    if not params.jit:
        logging.info("Not using torch.jit.script")
        # Save it using a format so that it can be loaded
        # by :func:`load_checkpoint`
        out_path = (
            params.exp_dir
            / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
        )
        torch.save({"model": model.state_dict()}, str(out_path))
        logging.info(f"Saved to {out_path}")
        return

    # We won't use the forward() method of the model in C++, so just ignore
    # it here.
    # Otherwise, one of its arguments is a ragged tensor and is not
    # torch scriptable.
    model.__class__.forward = torch.jit.ignore(model.__class__.forward)
    logging.info("Using torch.jit.script")
    scripted = torch.jit.script(model)
    out_path = (
        params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
    )
    scripted.save(str(out_path))
    logging.info(f"Saved to {out_path}")
|
||||
|
||||
|
||||
# Configure logging and run the export only when executed as a script.
if __name__ == "__main__":
    # Log format: timestamp, level, source file and line, then the message.
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
|
1
egs/aishell/ASR/pruned_transducer_stateless2/joiner.py
Symbolic link
1
egs/aishell/ASR/pruned_transducer_stateless2/joiner.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless2/joiner.py
|
1
egs/aishell/ASR/pruned_transducer_stateless2/model.py
Symbolic link
1
egs/aishell/ASR/pruned_transducer_stateless2/model.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless2/model.py
|
1
egs/aishell/ASR/pruned_transducer_stateless2/optim.py
Symbolic link
1
egs/aishell/ASR/pruned_transducer_stateless2/optim.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless2/optim.py
|
337
egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py
Executable file
337
egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py
Executable file
@ -0,0 +1,337 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||
# Wei Kang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Usage:
|
||||
|
||||
(1) greedy search
|
||||
./pruned_transducer_stateless2/pretrained.py \
|
||||
--checkpoint /path/to/pretrained.pt \
|
||||
--lang-dir /path/to/lang_char \
|
||||
--method greedy_search \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(2) beam search
|
||||
./pruned_transducer_stateless2/pretrained.py \
|
||||
--checkpoint /path/to/pretrained.pt \
|
||||
--lang-dir /path/to/lang_char \
|
||||
--method beam_search \
|
||||
--beam-size 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(3) modified beam search
|
||||
./pruned_transducer_stateless2/pretrained.py \
|
||||
--checkpoint /path/to/pretrained.pt \
|
||||
--lang-dir /path/to/lang_char \
|
||||
--method modified_beam_search \
|
||||
--beam-size 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(4) fast beam search
|
||||
./pruned_transducer_stateless2/pretrained.py \
|
||||
--checkpoint /path/to/pretrained.pt \
|
||||
--lang-dir /path/to/lang_char \
|
||||
--method fast_beam_search \
|
||||
--beam 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import k2
|
||||
import kaldifeat
|
||||
import torch
|
||||
import torchaudio
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.lexicon import Lexicon
|
||||
|
||||
|
||||
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for decoding sound files.

    Returns:
      A parser holding the checkpoint, decoding and input-file options
      declared below, plus whatever :func:`add_model_arguments` registers
      on it.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Path to the checkpoint. "
        "The checkpoint is assumed to be saved by "
        "icefall.checkpoint.save_checkpoint().",
    )

    parser.add_argument(
        "--lang-dir",
        type=Path,
        default=Path("data/lang_char"),
        help="The lang dir",
    )

    parser.add_argument(
        "--method",
        type=str,
        default="greedy_search",
        help="""Possible values are:
          - greedy_search
          - beam_search
          - modified_beam_search
          - fast_beam_search
        """,
    )

    # Positional: one or more files to transcribe.
    parser.add_argument(
        "sound_files",
        type=str,
        nargs="+",
        help="The input sound file(s) to transcribe. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported. "
        "The sample rate has to be 16kHz.",
    )

    parser.add_argument(
        "--sample-rate",
        type=int,
        default=16000,
        help="The sample rate of the input sound file",
    )

    # Note the two different "beam" options: --beam-size is a candidate
    # count, --beam is a score margin.
    parser.add_argument(
        "--beam-size",
        type=int,
        default=4,
        help="""An integer indicating how many candidates we will keep for each
        frame. Used only when --method is beam_search or
        modified_beam_search.""",
    )

    parser.add_argument(
        "--beam",
        type=float,
        default=4,
        help="""A floating point value to calculate the cutoff score during beam
        search (i.e., `cutoff = max-score - beam`), which is the same as the
        `beam` in Kaldi.
        Used only when --method is fast_beam_search""",
    )

    parser.add_argument(
        "--max-contexts",
        type=int,
        default=4,
        help="""Used only when --method is fast_beam_search""",
    )

    parser.add_argument(
        "--max-states",
        type=int,
        default=8,
        help="""Used only when --method is fast_beam_search""",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=1,
        help="The context size in the decoder. 1 means bigram; "
        "2 means tri-gram",
    )
    parser.add_argument(
        "--max-sym-per-frame",
        type=int,
        default=1,
        help="Maximum number of symbols per frame. "
        "Use only when --method is greedy_search",
    )

    # Model-architecture options are declared by the training script.
    add_model_arguments(parser)

    return parser
|
||||
|
||||
|
||||
def read_sound_files(
    filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
    """Load every given sound file and keep only its first channel.

    Args:
      filenames:
        Paths of the sound files to load, in the order they should be
        returned.
      expected_sample_rate:
        The sample rate every file is required to have. A file with a
        different rate trips an assertion.
    Returns:
      A list with one tensor per input file (same order as `filenames`),
      each being channel 0 of what torchaudio.load() returned.
    """
    first_channels = []
    for path in filenames:
        loaded = torchaudio.load(path)
        wave, sample_rate = loaded
        assert sample_rate == expected_sample_rate, (
            f"expected sample rate: {expected_sample_rate}. "
            f"Given: {sample_rate}"
        )
        # Any additional channels are dropped; only channel 0 is decoded.
        channel_0 = wave[0]
        first_channels.append(channel_0)
    return first_channels
|
||||
|
||||
|
||||
@torch.no_grad()
def main():
    """Transcribe the sound files named on the command line and log the result.

    Steps, in order:
      1. Parse the command line and merge it into the default parameters.
      2. Select CUDA device 0 when CUDA is available, otherwise the CPU.
      3. Build the transducer model and load `checkpoint["model"]` into it
         non-strictly (`strict=False`).
      4. Compute fbank features for all files, pad them into one batch and
         run the encoder once.
      5. Decode with the method selected by `--method` and log, for each
         input file, its tokens joined by single spaces.

    Raises:
      ValueError: if `--method` is not one of the supported decoding methods.
    """
    args = get_parser().parse_args()

    params = get_params()
    params.update(vars(args))

    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    else:
        device = torch.device("cpu")

    logging.info(f"device: {device}")

    lexicon = Lexicon(params.lang_dir)

    params.blank_id = 0
    # Token ids are taken to be contiguous from 0, so the largest id fixes
    # the vocabulary size.
    params.vocab_size = max(lexicon.tokens) + 1

    logging.info(params)

    logging.info("About to create model")
    model = get_transducer_model(params)

    state = torch.load(args.checkpoint, map_location="cpu")
    # Non-strict loading: key mismatches between checkpoint and model are
    # tolerated rather than raising.
    model.load_state_dict(state["model"], strict=False)
    model.to(device)
    model.eval()
    model.device = device

    logging.info("Constructing Fbank computer")
    fbank_opts = kaldifeat.FbankOptions()
    fbank_opts.device = device
    fbank_opts.frame_opts.dither = 0
    fbank_opts.frame_opts.snip_edges = False
    fbank_opts.frame_opts.samp_freq = params.sample_rate
    fbank_opts.mel_opts.num_bins = params.feature_dim

    fbank = kaldifeat.Fbank(fbank_opts)

    logging.info(f"Reading sound files: {params.sound_files}")
    waves = [
        w.to(device)
        for w in read_sound_files(
            filenames=params.sound_files,
            expected_sample_rate=params.sample_rate,
        )
    ]

    logging.info("Decoding started")
    features = fbank(waves)
    feature_lens = torch.tensor(
        [f.size(0) for f in features], device=device
    )

    # Shorter utterances are padded with log(1e-10) up to the longest one.
    features = pad_sequence(
        features, batch_first=True, padding_value=math.log(1e-10)
    )

    encoder_out, encoder_out_lens = model.encoder(
        x=features, x_lens=feature_lens
    )

    method = params.method
    logging.info(f"Using {params.method}")

    if method == "fast_beam_search":
        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
        hyp_list = fast_beam_search_one_best(
            model=model,
            decoding_graph=decoding_graph,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
        )
    elif method == "greedy_search" and params.max_sym_per_frame == 1:
        # Batched greedy search is used only for one symbol per frame;
        # other settings go through the per-utterance loop below.
        hyp_list = greedy_search_batch(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
        )
    elif method == "modified_beam_search":
        hyp_list = modified_beam_search(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam_size,
        )
    else:
        # The remaining methods decode one utterance at a time.
        hyp_list = []
        for idx in range(encoder_out.size(0)):
            # Keep the batch dimension and drop the padded frames.
            # fmt: off
            encoder_out_i = encoder_out[idx:idx+1, :encoder_out_lens[idx]]
            # fmt: on
            if method == "greedy_search":
                hyp = greedy_search(
                    model=model,
                    encoder_out=encoder_out_i,
                    max_sym_per_frame=params.max_sym_per_frame,
                )
            elif method == "beam_search":
                hyp = beam_search(
                    model=model,
                    encoder_out=encoder_out_i,
                    beam=params.beam_size,
                )
            else:
                raise ValueError(
                    f"Unsupported decoding method: {params.method}"
                )
            hyp_list.append(hyp)

    # Map token ids back to their symbols.
    hyps = [[lexicon.token_table[i] for i in hyp] for hyp in hyp_list]

    transcripts = (" ".join(tokens) for tokens in hyps)
    s = "\n" + "".join(
        f"{filename}:\n{words}\n\n"
        for filename, words in zip(params.sound_files, transcripts)
    )
    logging.info(s)

    logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Log lines carry the timestamp, the level and the emitting file:line.
    log_format = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )
    logging.basicConfig(level=logging.INFO, format=log_format)

    main()
|
1
egs/aishell/ASR/pruned_transducer_stateless2/scaling.py
Symbolic link
1
egs/aishell/ASR/pruned_transducer_stateless2/scaling.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless2/scaling.py
|
1057
egs/aishell/ASR/pruned_transducer_stateless2/train.py
Executable file
1057
egs/aishell/ASR/pruned_transducer_stateless2/train.py
Executable file
File diff suppressed because it is too large
Load Diff
@ -377,6 +377,7 @@ def decode_dataset(
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
@ -389,9 +390,9 @@ def decode_dataset(
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for hyp_words, ref_text in zip(hyps, texts):
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((ref_words, hyp_words))
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
@ -416,6 +417,7 @@ def save_results(
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
@ -427,7 +429,9 @@ def save_results(
|
||||
# we compute CER for aishell dataset.
|
||||
results_char = []
|
||||
for res in results:
|
||||
results_char.append((list("".join(res[0])), list("".join(res[1]))))
|
||||
results_char.append(
|
||||
(res[0], list("".join(res[1])), list("".join(res[2])))
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results_char, enable_log=True
|
||||
@ -464,6 +468,7 @@ def main():
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
params.datatang_prob = 0
|
||||
|
||||
assert params.decoding_method in (
|
||||
"greedy_search",
|
||||
@ -605,6 +610,8 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
asr_datamodule = AsrDataModule(args)
|
||||
aishell = AIShell(manifest_dir=args.manifest_dir)
|
||||
test_cuts = aishell.test_cuts()
|
||||
|
@ -157,6 +157,7 @@ def main():
|
||||
|
||||
params.blank_id = 0
|
||||
params.vocab_size = max(lexicon.tokens) + 1
|
||||
params.datatang_prob = 0
|
||||
|
||||
logging.info(params)
|
||||
|
||||
|
@ -223,6 +223,7 @@ def main():
|
||||
|
||||
params.blank_id = 0
|
||||
params.vocab_size = max(lexicon.tokens) + 1
|
||||
params.datatang_prob = 0
|
||||
|
||||
logging.info(params)
|
||||
|
||||
|
@ -22,8 +22,12 @@
|
||||
Usage:
|
||||
|
||||
./prepare.sh
|
||||
|
||||
# If you use a non-zero value for --datatang-prob, you also need to run
|
||||
./prepare_aidatatang_200zh.sh
|
||||
|
||||
If you use --datatang-prob=0, then you don't need to run the above script.
|
||||
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
|
||||
|
||||
@ -343,9 +347,12 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--datatang-prob",
|
||||
type=float,
|
||||
default=0.2,
|
||||
help="The probability to select a batch from the "
|
||||
"aidatatang_200zh dataset",
|
||||
default=0.0,
|
||||
help="""The probability to select a batch from the
|
||||
aidatatang_200zh dataset.
|
||||
If it is set to 0, you don't need to download the data
|
||||
for aidatatang_200zh.
|
||||
""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
@ -457,8 +464,12 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
|
||||
decoder = get_decoder_model(params)
|
||||
joiner = get_joiner_model(params)
|
||||
|
||||
decoder_datatang = get_decoder_model(params)
|
||||
joiner_datatang = get_joiner_model(params)
|
||||
if params.datatang_prob > 0:
|
||||
decoder_datatang = get_decoder_model(params)
|
||||
joiner_datatang = get_joiner_model(params)
|
||||
else:
|
||||
decoder_datatang = None
|
||||
joiner_datatang = None
|
||||
|
||||
model = Transducer(
|
||||
encoder=encoder,
|
||||
@ -726,7 +737,7 @@ def train_one_epoch(
|
||||
scheduler: LRSchedulerType,
|
||||
graph_compiler: CharCtcTrainingGraphCompiler,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
datatang_train_dl: torch.utils.data.DataLoader,
|
||||
datatang_train_dl: Optional[torch.utils.data.DataLoader],
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
rng: random.Random,
|
||||
scaler: GradScaler,
|
||||
@ -778,13 +789,17 @@ def train_one_epoch(
|
||||
dl_weights = [1 - params.datatang_prob, params.datatang_prob]
|
||||
|
||||
iter_aishell = iter(train_dl)
|
||||
iter_datatang = iter(datatang_train_dl)
|
||||
if datatang_train_dl is not None:
|
||||
iter_datatang = iter(datatang_train_dl)
|
||||
|
||||
batch_idx = 0
|
||||
|
||||
while True:
|
||||
idx = rng.choices((0, 1), weights=dl_weights, k=1)[0]
|
||||
dl = iter_aishell if idx == 0 else iter_datatang
|
||||
if datatang_train_dl is not None:
|
||||
idx = rng.choices((0, 1), weights=dl_weights, k=1)[0]
|
||||
dl = iter_aishell if idx == 0 else iter_datatang
|
||||
else:
|
||||
dl = iter_aishell
|
||||
|
||||
try:
|
||||
batch = next(dl)
|
||||
@ -808,7 +823,11 @@ def train_one_epoch(
|
||||
warmup=(params.batch_idx_train / params.model_warm_step),
|
||||
)
|
||||
# summary stats
|
||||
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
|
||||
if datatang_train_dl is not None:
|
||||
tot_loss = (
|
||||
tot_loss * (1 - 1 / params.reset_interval)
|
||||
) + loss_info
|
||||
|
||||
if aishell:
|
||||
aishell_tot_loss = (
|
||||
aishell_tot_loss * (1 - 1 / params.reset_interval)
|
||||
@ -871,12 +890,21 @@ def train_one_epoch(
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
if datatang_train_dl is not None:
|
||||
datatang_str = f"datatang_tot_loss[{datatang_tot_loss}], "
|
||||
tot_loss_str = (
|
||||
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
|
||||
)
|
||||
else:
|
||||
tot_loss_str = ""
|
||||
datatang_str = ""
|
||||
|
||||
logging.info(
|
||||
f"Epoch {params.cur_epoch}, "
|
||||
f"batch {batch_idx}, {prefix}_loss[{loss_info}], "
|
||||
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
|
||||
f"{tot_loss_str}"
|
||||
f"aishell_tot_loss[{aishell_tot_loss}], "
|
||||
f"datatang_tot_loss[{datatang_tot_loss}], "
|
||||
f"{datatang_str}"
|
||||
f"batch size: {batch_size}, "
|
||||
f"lr: {cur_lr:.2e}"
|
||||
)
|
||||
@ -891,15 +919,18 @@ def train_one_epoch(
|
||||
f"train/current_{prefix}_",
|
||||
params.batch_idx_train,
|
||||
)
|
||||
tot_loss.write_summary(
|
||||
tb_writer, "train/tot_", params.batch_idx_train
|
||||
)
|
||||
if datatang_train_dl is not None:
|
||||
# If it is None, tot_loss is the same as aishell_tot_loss.
|
||||
tot_loss.write_summary(
|
||||
tb_writer, "train/tot_", params.batch_idx_train
|
||||
)
|
||||
aishell_tot_loss.write_summary(
|
||||
tb_writer, "train/aishell_tot_", params.batch_idx_train
|
||||
)
|
||||
datatang_tot_loss.write_summary(
|
||||
tb_writer, "train/datatang_tot_", params.batch_idx_train
|
||||
)
|
||||
if datatang_train_dl is not None:
|
||||
datatang_tot_loss.write_summary(
|
||||
tb_writer, "train/datatang_tot_", params.batch_idx_train
|
||||
)
|
||||
|
||||
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
|
||||
logging.info("Computing validation loss")
|
||||
@ -917,7 +948,10 @@ def train_one_epoch(
|
||||
tb_writer, "train/valid_", params.batch_idx_train
|
||||
)
|
||||
|
||||
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
||||
if datatang_train_dl is not None:
|
||||
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
||||
else:
|
||||
loss_value = aishell_tot_loss["loss"] / aishell_tot_loss["frames"]
|
||||
params.train_loss = loss_value
|
||||
if params.train_loss < params.best_train_loss:
|
||||
params.best_train_epoch = params.cur_epoch
|
||||
@ -1004,7 +1038,16 @@ def run(rank, world_size, args):
|
||||
model.to(device)
|
||||
if world_size > 1:
|
||||
logging.info("Using DDP")
|
||||
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
|
||||
if params.datatang_prob > 0:
|
||||
find_unused_parameters = True
|
||||
else:
|
||||
find_unused_parameters = False
|
||||
|
||||
model = DDP(
|
||||
model,
|
||||
device_ids=[rank],
|
||||
find_unused_parameters=find_unused_parameters,
|
||||
)
|
||||
|
||||
optimizer = Eve(model.parameters(), lr=params.initial_lr)
|
||||
|
||||
@ -1032,11 +1075,6 @@ def run(rank, world_size, args):
|
||||
train_cuts = aishell.train_cuts()
|
||||
train_cuts = filter_short_and_long_utterances(train_cuts)
|
||||
|
||||
datatang = AIDatatang200zh(manifest_dir=args.manifest_dir)
|
||||
train_datatang_cuts = datatang.train_cuts()
|
||||
train_datatang_cuts = filter_short_and_long_utterances(train_datatang_cuts)
|
||||
train_datatang_cuts = train_datatang_cuts.repeat(times=None)
|
||||
|
||||
if args.enable_musan:
|
||||
cuts_musan = load_manifest(
|
||||
Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
|
||||
@ -1052,11 +1090,21 @@ def run(rank, world_size, args):
|
||||
cuts_musan=cuts_musan,
|
||||
)
|
||||
|
||||
datatang_train_dl = asr_datamodule.train_dataloaders(
|
||||
train_datatang_cuts,
|
||||
on_the_fly_feats=False,
|
||||
cuts_musan=cuts_musan,
|
||||
)
|
||||
if params.datatang_prob > 0:
|
||||
datatang = AIDatatang200zh(manifest_dir=args.manifest_dir)
|
||||
train_datatang_cuts = datatang.train_cuts()
|
||||
train_datatang_cuts = filter_short_and_long_utterances(
|
||||
train_datatang_cuts
|
||||
)
|
||||
train_datatang_cuts = train_datatang_cuts.repeat(times=None)
|
||||
datatang_train_dl = asr_datamodule.train_dataloaders(
|
||||
train_datatang_cuts,
|
||||
on_the_fly_feats=False,
|
||||
cuts_musan=cuts_musan,
|
||||
)
|
||||
else:
|
||||
datatang_train_dl = None
|
||||
logging.info("Not using aidatatang_200zh for training")
|
||||
|
||||
valid_cuts = aishell.valid_cuts()
|
||||
valid_dl = asr_datamodule.valid_dataloaders(valid_cuts)
|
||||
@ -1065,13 +1113,14 @@ def run(rank, world_size, args):
|
||||
train_dl,
|
||||
# datatang_train_dl
|
||||
]:
|
||||
scan_pessimistic_batches_for_oom(
|
||||
model=model,
|
||||
train_dl=dl,
|
||||
optimizer=optimizer,
|
||||
graph_compiler=graph_compiler,
|
||||
params=params,
|
||||
)
|
||||
if dl is not None:
|
||||
scan_pessimistic_batches_for_oom(
|
||||
model=model,
|
||||
train_dl=dl,
|
||||
optimizer=optimizer,
|
||||
graph_compiler=graph_compiler,
|
||||
params=params,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
@ -1083,7 +1132,8 @@ def run(rank, world_size, args):
|
||||
scheduler.step_epoch(epoch - 1)
|
||||
fix_random_seed(params.seed + epoch - 1)
|
||||
train_dl.sampler.set_epoch(epoch - 1)
|
||||
datatang_train_dl.sampler.set_epoch(epoch)
|
||||
if datatang_train_dl is not None:
|
||||
datatang_train_dl.sampler.set_epoch(epoch)
|
||||
|
||||
if tb_writer is not None:
|
||||
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
|
||||
|
@ -241,6 +241,7 @@ def decode_dataset(
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
@ -253,9 +254,9 @@ def decode_dataset(
|
||||
for lm_scale, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for hyp_words, ref_text in zip(hyps, texts):
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((ref_words, hyp_words))
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[lm_scale].extend(this_batch)
|
||||
|
||||
@ -278,6 +279,7 @@ def save_results(
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
@ -287,7 +289,9 @@ def save_results(
|
||||
# We compute CER for aishell dataset.
|
||||
results_char = []
|
||||
for res in results:
|
||||
results_char.append((list("".join(res[0])), list("".join(res[1]))))
|
||||
results_char.append(
|
||||
(res[0], list("".join(res[1])), list("".join(res[2])))
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(f, f"{test_set_name}-{key}", results_char)
|
||||
test_set_wers[key] = wer
|
||||
@ -365,6 +369,8 @@ def main():
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
aishell = AishellAsrDataModule(args)
|
||||
test_cuts = aishell.test_cuts()
|
||||
test_dl = aishell.test_dataloaders(test_cuts)
|
||||
|
@ -38,8 +38,8 @@ from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
write_error_stats,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
|
||||
@ -296,6 +296,7 @@ def decode_dataset(
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
@ -307,9 +308,9 @@ def decode_dataset(
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for hyp_words, ref_text in zip(hyps, texts):
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((ref_words, hyp_words))
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
@ -334,6 +335,7 @@ def save_results(
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
@ -344,7 +346,9 @@ def save_results(
|
||||
# we compute CER for aishell dataset.
|
||||
results_char = []
|
||||
for res in results:
|
||||
results_char.append((list("".join(res[0])), list("".join(res[1]))))
|
||||
results_char.append(
|
||||
(res[0], list("".join(res[1])), list("".join(res[2])))
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results_char, enable_log=True
|
||||
@ -438,6 +442,8 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
aishell = AishellAsrDataModule(args)
|
||||
test_cuts = aishell.test_cuts()
|
||||
test_dl = aishell.test_dataloaders(test_cuts)
|
||||
|
@ -341,6 +341,7 @@ def decode_dataset(
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
@ -353,9 +354,9 @@ def decode_dataset(
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for hyp_words, ref_text in zip(hyps, texts):
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((ref_words, hyp_words))
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
@ -380,6 +381,7 @@ def save_results(
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
@ -391,7 +393,9 @@ def save_results(
|
||||
# we compute CER for aishell dataset.
|
||||
results_char = []
|
||||
for res in results:
|
||||
results_char.append((list("".join(res[0])), list("".join(res[1]))))
|
||||
results_char.append(
|
||||
(res[0], list("".join(res[1])), list("".join(res[2])))
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results_char, enable_log=True
|
||||
@ -496,6 +500,8 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
asr_datamodule = AsrDataModule(args)
|
||||
aishell = AIShell(manifest_dir=args.manifest_dir)
|
||||
test_cuts = aishell.test_cuts()
|
||||
|
@ -345,6 +345,7 @@ def decode_dataset(
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
@ -357,9 +358,9 @@ def decode_dataset(
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for hyp_words, ref_text in zip(hyps, texts):
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((ref_words, hyp_words))
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
@ -384,6 +385,7 @@ def save_results(
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
@ -395,7 +397,9 @@ def save_results(
|
||||
# we compute CER for aishell dataset.
|
||||
results_char = []
|
||||
for res in results:
|
||||
results_char.append((list("".join(res[0])), list("".join(res[1]))))
|
||||
results_char.append(
|
||||
(res[0], list("".join(res[1])), list("".join(res[2])))
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results_char, enable_log=True
|
||||
@ -498,6 +502,8 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
aishell = AishellAsrDataModule(args)
|
||||
test_cuts = aishell.test_cuts()
|
||||
test_dl = aishell.test_dataloaders(test_cuts)
|
||||
|
@ -514,6 +514,7 @@ def decode_dataset(
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
@ -527,8 +528,8 @@ def decode_dataset(
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for hyp_words, ref_text in zip(hyps, texts):
|
||||
this_batch.append((ref_text, hyp_words))
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
this_batch.append((cut_id, ref_text, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
@ -553,6 +554,7 @@ def save_results(
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
@ -756,6 +758,8 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
aishell2 = AiShell2AsrDataModule(args)
|
||||
|
||||
valid_cuts = aishell2.valid_cuts()
|
||||
|
@ -378,6 +378,7 @@ def decode_dataset(
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
texts = [list(str(text).replace(" ", "")) for text in texts]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
@ -390,8 +391,8 @@ def decode_dataset(
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for hyp_words, ref_text in zip(hyps, texts):
|
||||
this_batch.append((ref_text, hyp_words))
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
this_batch.append((cut_id, ref_text, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
@ -416,6 +417,7 @@ def save_results(
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
@ -607,6 +609,8 @@ def main():
|
||||
c.supervisions[0].text = text_normalize(text)
|
||||
return c
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
aishell4 = Aishell4AsrDataModule(args)
|
||||
test_cuts = aishell4.test_cuts()
|
||||
test_cuts = test_cuts.map(text_normalize_for_cut)
|
||||
|
@ -367,6 +367,7 @@ def decode_dataset(
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
texts = [list(str(text).replace(" ", "")) for text in texts]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
@ -379,8 +380,8 @@ def decode_dataset(
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for hyp_words, ref_text in zip(hyps, texts):
|
||||
this_batch.append((ref_text, hyp_words))
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
this_batch.append((cut_id, ref_text, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
@ -405,6 +406,7 @@ def save_results(
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
@ -535,6 +537,8 @@ def main():
|
||||
from lhotse import CutSet
|
||||
from lhotse.dataset.webdataset import export_to_webdataset
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
alimeeting = AlimeetingAsrDataModule(args)
|
||||
|
||||
dev = "eval"
|
||||
|
@ -451,6 +451,7 @@ def decode_dataset(
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
@ -469,9 +470,9 @@ def decode_dataset(
|
||||
for lm_scale, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for hyp_words, ref_text in zip(hyps, texts):
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((ref_words, hyp_words))
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[lm_scale].extend(this_batch)
|
||||
else:
|
||||
@ -512,6 +513,7 @@ def save_results(
|
||||
for key, results in results_dict.items():
|
||||
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
|
||||
results = post_processing(results)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
if enable_log:
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
@ -676,6 +678,8 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
gigaspeech = GigaSpeechAsrDataModule(args)
|
||||
|
||||
dev_cuts = gigaspeech.dev_cuts()
|
||||
|
@ -20,11 +20,7 @@ import logging
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import (
|
||||
CutSet,
|
||||
KaldifeatFbank,
|
||||
KaldifeatFbankConfig,
|
||||
)
|
||||
from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slow things down.
|
||||
@ -69,6 +65,7 @@ def compute_fbank_gigaspeech_dev_test():
|
||||
storage_path=f"{in_out_dir}/feats_{partition}",
|
||||
num_workers=num_workers,
|
||||
batch_duration=batch_duration,
|
||||
overwrite=True,
|
||||
)
|
||||
cut_set = cut_set.trim_to_supervisions(
|
||||
keep_overlapping=False, min_duration=None
|
||||
|
@ -22,11 +22,7 @@ from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import (
|
||||
CutSet,
|
||||
KaldifeatFbank,
|
||||
KaldifeatFbankConfig,
|
||||
)
|
||||
from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slow things down.
|
||||
@ -120,6 +116,7 @@ def compute_fbank_gigaspeech_splits(args):
|
||||
storage_path=f"{output_dir}/feats_XL_{idx}",
|
||||
num_workers=args.num_workers,
|
||||
batch_duration=args.batch_duration,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
logging.info("About to split cuts into smaller chunks.")
|
||||
|
@ -374,6 +374,7 @@ def decode_dataset(
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
@ -386,9 +387,9 @@ def decode_dataset(
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for hyp_words, ref_text in zip(hyps, texts):
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((ref_words, hyp_words))
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
@ -414,6 +415,7 @@ def save_results(
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = post_processing(results)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
@ -544,6 +546,8 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
gigaspeech = GigaSpeechAsrDataModule(args)
|
||||
|
||||
dev_cuts = gigaspeech.dev_cuts()
|
||||
|
@ -25,6 +25,7 @@ The following table lists the differences among them.
|
||||
| `pruned_stateless_emformer_rnnt2` | Emformer(from torchaudio) | Embedding + Conv1d | Using Emformer from torchaudio for streaming ASR|
|
||||
| `conv_emformer_transducer_stateless` | ConvEmformer | Embedding + Conv1d | Using ConvEmformer for streaming ASR + mechanisms in reworked model |
|
||||
| `conv_emformer_transducer_stateless2` | ConvEmformer | Embedding + Conv1d | Using ConvEmformer with simplified memory for streaming ASR + mechanisms in reworked model |
|
||||
| `lstm_transducer_stateless` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model |
|
||||
|
||||
The decoder in `transducer_stateless` is modified from the paper
|
||||
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
|
||||
|
@ -1,5 +1,91 @@
|
||||
## Results
|
||||
|
||||
#### LibriSpeech BPE training results (Pruned Stateless LSTM RNN-T)
|
||||
|
||||
[lstm_transducer_stateless](./lstm_transducer_stateless)
|
||||
|
||||
It implements an LSTM model with mechanisms from the reworked model for streaming ASR.
|
||||
|
||||
See <https://github.com/k2-fsa/icefall/pull/479> for more details.
|
||||
|
||||
#### training on full librispeech
|
||||
|
||||
This model contains 12 encoder layers (LSTM module + Feedforward module). The number of model parameters is 84689496.
|
||||
|
||||
The WERs are:
|
||||
|
||||
| | test-clean | test-other | comment | decoding mode |
|
||||
|-------------------------------------|------------|------------|----------------------|----------------------|
|
||||
| greedy search (max sym per frame 1) | 3.81 | 9.73 | --epoch 35 --avg 15 | simulated streaming |
|
||||
| greedy search (max sym per frame 1) | 3.78 | 9.79 | --epoch 35 --avg 15 | streaming |
|
||||
| fast beam search | 3.74 | 9.59 | --epoch 35 --avg 15 | simulated streaming |
|
||||
| fast beam search | 3.73 | 9.61 | --epoch 35 --avg 15 | streaming |
|
||||
| modified beam search | 3.64 | 9.55 | --epoch 35 --avg 15 | simulated streaming |
|
||||
| modified beam search | 3.65 | 9.51 | --epoch 35 --avg 15 | streaming |
|
||||
|
||||
Note: `simulated streaming` indicates feeding the full utterance during decoding, while `streaming` indicates feeding a certain number of frames at a time.
|
||||
|
||||
The training command is:
|
||||
|
||||
```bash
|
||||
./lstm_transducer_stateless/train.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 35 \
|
||||
--start-epoch 1 \
|
||||
--exp-dir lstm_transducer_stateless/exp \
|
||||
--full-libri 1 \
|
||||
--max-duration 500 \
|
||||
--master-port 12321 \
|
||||
--num-encoder-layers 12 \
|
||||
--rnn-hidden-size 1024
|
||||
```
|
||||
|
||||
The tensorboard log can be found at
|
||||
<https://tensorboard.dev/experiment/FWrM20mjTeWo6dTpFYOsYQ/>
|
||||
|
||||
The simulated streaming decoding command using greedy search, fast beam search, and modified beam search is:
|
||||
```bash
|
||||
for decoding_method in greedy_search fast_beam_search modified_beam_search; do
|
||||
./lstm_transducer_stateless/decode.py \
|
||||
--epoch 35 \
|
||||
--avg 15 \
|
||||
--exp-dir lstm_transducer_stateless/exp \
|
||||
--max-duration 600 \
|
||||
--num-encoder-layers 12 \
|
||||
--rnn-hidden-size 1024 \
|
||||
--decoding-method $decoding_method \
|
||||
--use-averaged-model True \
|
||||
--beam 4 \
|
||||
--max-contexts 4 \
|
||||
--max-states 8 \
|
||||
--beam-size 4
|
||||
done
|
||||
```
|
||||
|
||||
The streaming decoding command using greedy search, fast beam search, and modified beam search is:
|
||||
```bash
|
||||
for decoding_method in greedy_search fast_beam_search modified_beam_search; do
|
||||
./lstm_transducer_stateless/streaming_decode.py \
|
||||
--epoch 35 \
|
||||
--avg 15 \
|
||||
--exp-dir lstm_transducer_stateless/exp \
|
||||
--max-duration 600 \
|
||||
--num-encoder-layers 12 \
|
||||
--rnn-hidden-size 1024 \
|
||||
--decoding-method $decoding_method \
|
||||
--use-averaged-model True \
|
||||
--beam 4 \
|
||||
--max-contexts 4 \
|
||||
--max-states 8 \
|
||||
--beam-size 4
|
||||
done
|
||||
```
|
||||
|
||||
Pretrained models, training logs, decoding logs, and decoding results
|
||||
are available at
|
||||
<https://huggingface.co/Zengwei/icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18>
|
||||
|
||||
|
||||
#### LibriSpeech BPE training results (Pruned Stateless Conv-Emformer RNN-T 2)
|
||||
|
||||
[conv_emformer_transducer_stateless2](./conv_emformer_transducer_stateless2)
|
||||
@ -618,6 +704,80 @@ done
|
||||
|
||||
Pre-trained models, training and decoding logs, and decoding results are available at <https://huggingface.co/pkufool/icefall_librispeech_streaming_pruned_transducer_stateless4_20220625>
|
||||
|
||||
#### [pruned_transducer_stateless5](./pruned_transducer_stateless5)
|
||||
|
||||
See <https://github.com/k2-fsa/icefall/pull/454> for more details.
|
||||
|
||||
##### Training on full librispeech
|
||||
The WERs are (the numbers in the table are formatted as test-clean & test-other):
|
||||
|
||||
We only trained for 25 epochs to save time; if you want better results, you can train for more epochs.
|
||||
|
||||
| decoding method | left context | chunk size = 2 | chunk size = 4 | chunk size = 8 | chunk size = 16|
|
||||
|----------------------|--------------|----------------|----------------|----------------|----------------|
|
||||
| greedy search | 32 | 3.93 & 9.88 | 3.64 & 9.43 | 3.51 & 8.92 | 3.26 & 8.37 |
|
||||
| greedy search | 64 | 4.84 & 9.81 | 3.59 & 9.27 | 3.44 & 8.83 | 3.23 & 8.33 |
|
||||
| fast beam search | 32 | 3.86 & 9.77 | 3.67 & 9.3 | 3.5 & 8.83 | 3.27 & 8.33 |
|
||||
| fast beam search | 64 | 3.79 & 9.68 | 3.57 & 9.21 | 3.41 & 8.72 | 3.25 & 8.27 |
|
||||
| modified beam search | 32 | 3.84 & 9.71 | 3.66 & 9.38 | 3.47 & 8.86 | 3.26 & 8.42 |
|
||||
| modified beam search | 64 | 3.81 & 9.59 | 3.58 & 9.2 | 3.44 & 8.74 | 3.23 & 8.35 |
|
||||
|
||||
|
||||
**NOTE:** The WERs in the table above were obtained with the simulated streaming method (i.e., using a masking strategy); see the commands below. We also have a [real streaming decoding](./pruned_transducer_stateless5/streaming_decode.py) script, which should produce almost the same results. We tried adding right context in real streaming decoding, but it did not seem to benefit the performance of the models; the reason might be a mismatch between training and decoding.
|
||||
|
||||
The training command is:
|
||||
|
||||
```bash
|
||||
./pruned_transducer_stateless5/train.py \
|
||||
--exp-dir pruned_transducer_stateless5/exp \
|
||||
--num-encoder-layers 18 \
|
||||
--dim-feedforward 2048 \
|
||||
--nhead 8 \
|
||||
--encoder-dim 512 \
|
||||
--decoder-dim 512 \
|
||||
--joiner-dim 512 \
|
||||
--full-libri 1 \
|
||||
--dynamic-chunk-training 1 \
|
||||
--causal-convolution 1 \
|
||||
--short-chunk-size 20 \
|
||||
--num-left-chunks 4 \
|
||||
--max-duration 300 \
|
||||
--world-size 4 \
|
||||
--start-epoch 1 \
|
||||
--num-epochs 25
|
||||
```
|
||||
|
||||
You can find the tensorboard log here <https://tensorboard.dev/experiment/rO04h6vjTLyw0qSxjp4m4Q>
|
||||
|
||||
The decoding command is:
|
||||
```bash
|
||||
decoding_method="greedy_search" # "fast_beam_search", "modified_beam_search"
|
||||
|
||||
for chunk in 2 4 8 16; do
|
||||
for left in 32 64; do
|
||||
./pruned_transducer_stateless5/decode.py \
|
||||
--num-encoder-layers 18 \
|
||||
--dim-feedforward 2048 \
|
||||
--nhead 8 \
|
||||
--encoder-dim 512 \
|
||||
--decoder-dim 512 \
|
||||
--joiner-dim 512 \
|
||||
--simulate-streaming 1 \
|
||||
--decode-chunk-size ${chunk} \
|
||||
--left-context ${left} \
|
||||
--causal-convolution 1 \
|
||||
--epoch 25 \
|
||||
--avg 3 \
|
||||
--exp-dir ./pruned_transducer_stateless5/exp \
|
||||
--max-sym-per-frame 1 \
|
||||
--max-duration 1000 \
|
||||
--decoding-method ${decoding_method}
|
||||
done
|
||||
done
|
||||
```
|
||||
|
||||
Pre-trained models, training and decoding logs, and decoding results are available at <https://huggingface.co/pkufool/icefall_librispeech_streaming_pruned_transducer_stateless5_20220729>
|
||||
|
||||
|
||||
### LibriSpeech BPE training results (Pruned Stateless Conv-Emformer RNN-T)
|
||||
|
||||
|
@ -525,6 +525,7 @@ def decode_dataset(
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
@ -544,9 +545,9 @@ def decode_dataset(
|
||||
for lm_scale, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for hyp_words, ref_text in zip(hyps, texts):
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((ref_words, hyp_words))
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[lm_scale].extend(this_batch)
|
||||
else:
|
||||
@ -586,6 +587,7 @@ def save_results(
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
if enable_log:
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
@ -779,6 +781,8 @@ def main():
|
||||
)
|
||||
rnn_lm_model.eval()
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
|
||||
test_clean_cuts = librispeech.test_clean_cuts()
|
||||
|
@ -447,6 +447,17 @@ def compute_loss(
|
||||
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
|
||||
# `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa
|
||||
info["utterances"] = feature.size(0)
|
||||
# averaged input duration in frames over utterances
|
||||
info["utt_duration"] = supervisions["num_frames"].sum().item()
|
||||
# averaged padding proportion over utterances
|
||||
info["utt_pad_proportion"] = (
|
||||
((feature.size(1) - supervisions["num_frames"]) / feature.size(1))
|
||||
.sum()
|
||||
.item()
|
||||
)
|
||||
|
||||
return loss, info
|
||||
|
||||
|
||||
|
@ -31,14 +31,13 @@ import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from conformer import Conformer
|
||||
|
||||
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
|
||||
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
||||
from icefall.decode import (
|
||||
get_lattice,
|
||||
nbest_decoding,
|
||||
@ -633,6 +632,7 @@ def decode_dataset(
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
@ -652,9 +652,9 @@ def decode_dataset(
|
||||
for lm_scale, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for hyp_words, ref_text in zip(hyps, texts):
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((ref_words, hyp_words))
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[lm_scale].extend(this_batch)
|
||||
else:
|
||||
@ -694,6 +694,7 @@ def save_results(
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
if enable_log:
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
@ -956,6 +957,8 @@ def main():
|
||||
)
|
||||
rnn_lm_model.eval()
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
|
||||
test_clean_cuts = librispeech.test_clean_cuts()
|
||||
|
@ -605,6 +605,15 @@ def compute_loss(
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
|
||||
# `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa
|
||||
info["utterances"] = feature.size(0)
|
||||
# averaged input duration in frames over utterances
|
||||
info["utt_duration"] = feature_lens.sum().item()
|
||||
# averaged padding proportion over utterances
|
||||
info["utt_pad_proportion"] = (
|
||||
((feature.size(1) - feature_lens) / feature.size(1)).sum().item()
|
||||
)
|
||||
|
||||
return loss, info
|
||||
|
||||
|
||||
|
@ -449,6 +449,7 @@ def decode_dataset(
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
@ -466,9 +467,9 @@ def decode_dataset(
|
||||
for lm_scale, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for hyp_words, ref_text in zip(hyps, texts):
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((ref_words, hyp_words))
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[lm_scale].extend(this_batch)
|
||||
|
||||
@ -496,6 +497,7 @@ def save_results(
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
if enable_log:
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
@ -661,6 +663,8 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
# CAUTION: `test_sets` is for displaying only.
|
||||
# If you want to skip test-clean, you have to skip
|
||||
|
@ -403,6 +403,7 @@ def decode_dataset(
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
@ -415,9 +416,9 @@ def decode_dataset(
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for hyp_words, ref_text in zip(hyps, texts):
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((ref_words, hyp_words))
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
@ -442,6 +443,7 @@ def save_results(
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
@ -624,6 +626,8 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
|
||||
test_clean_cuts = librispeech.test_clean_cuts()
|
||||
|
@ -29,6 +29,7 @@ class Stream(object):
|
||||
def __init__(
|
||||
self,
|
||||
params: AttributeDict,
|
||||
cut_id: str,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
LOG_EPS: float = math.log(1e-10),
|
||||
@ -44,6 +45,7 @@ class Stream(object):
|
||||
The device to run this stream.
|
||||
"""
|
||||
self.LOG_EPS = LOG_EPS
|
||||
self.cut_id = cut_id
|
||||
|
||||
# Containing attention caches and convolution caches
|
||||
self.states: Optional[
|
||||
@ -138,6 +140,10 @@ class Stream(object):
|
||||
"""Return True if all feature frames are processed."""
|
||||
return self._done
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self.cut_id
|
||||
|
||||
def decoding_result(self) -> List[int]:
|
||||
"""Obtain current decoding result."""
|
||||
if self.decoding_method == "greedy_search":
|
||||
|
@ -74,7 +74,6 @@ from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
from lhotse import CutSet
|
||||
import numpy as np
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
@ -83,6 +82,7 @@ from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from beam_search import Hypothesis, HypothesisList, get_hyps_shape
|
||||
from emformer import LOG_EPSILON, stack_states, unstack_states
|
||||
from kaldifeat import Fbank, FbankOptions
|
||||
from lhotse import CutSet
|
||||
from stream import Stream
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
@ -678,6 +678,7 @@ def decode_dataset(
|
||||
# Each utterance has a Stream.
|
||||
stream = Stream(
|
||||
params=params,
|
||||
cut_id=cut.id,
|
||||
decoding_graph=decoding_graph,
|
||||
device=device,
|
||||
LOG_EPS=LOG_EPSILON,
|
||||
@ -711,6 +712,7 @@ def decode_dataset(
|
||||
for i in sorted(finished_streams, reverse=True):
|
||||
decode_results.append(
|
||||
(
|
||||
streams[i].id,
|
||||
streams[i].ground_truth.split(),
|
||||
sp.decode(streams[i].decoding_result()).split(),
|
||||
)
|
||||
@ -731,6 +733,7 @@ def decode_dataset(
|
||||
for i in sorted(finished_streams, reverse=True):
|
||||
decode_results.append(
|
||||
(
|
||||
streams[i].id,
|
||||
streams[i].ground_truth.split(),
|
||||
sp.decode(streams[i].decoding_result()).split(),
|
||||
)
|
||||
|
@ -686,6 +686,15 @@ def compute_loss(
|
||||
(feature_lens // params.subsampling_factor).sum().item()
|
||||
)
|
||||
|
||||
# `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa
|
||||
info["utterances"] = feature.size(0)
|
||||
# averaged input duration in frames over utterances
|
||||
info["utt_duration"] = feature_lens.sum().item()
|
||||
# averaged padding proportion over utterances
|
||||
info["utt_pad_proportion"] = (
|
||||
((feature.size(1) - feature_lens) / feature.size(1)).sum().item()
|
||||
)
|
||||
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
info["simple_loss"] = simple_loss.detach().cpu().item()
|
||||
|
@ -403,6 +403,7 @@ def decode_dataset(
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
@ -415,9 +416,9 @@ def decode_dataset(
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for hyp_words, ref_text in zip(hyps, texts):
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((ref_words, hyp_words))
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
@ -442,6 +443,7 @@ def save_results(
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
@ -624,6 +626,8 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
|
||||
test_clean_cuts = librispeech.test_clean_cuts()
|
||||
|
@ -74,7 +74,6 @@ from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
from lhotse import CutSet
|
||||
import numpy as np
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
@ -83,6 +82,7 @@ from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from beam_search import Hypothesis, HypothesisList, get_hyps_shape
|
||||
from emformer import LOG_EPSILON, stack_states, unstack_states
|
||||
from kaldifeat import Fbank, FbankOptions
|
||||
from lhotse import CutSet
|
||||
from stream import Stream
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
@ -678,6 +678,7 @@ def decode_dataset(
|
||||
# Each utterance has a Stream.
|
||||
stream = Stream(
|
||||
params=params,
|
||||
cut_id=cut.id,
|
||||
decoding_graph=decoding_graph,
|
||||
device=device,
|
||||
LOG_EPS=LOG_EPSILON,
|
||||
@ -711,6 +712,7 @@ def decode_dataset(
|
||||
for i in sorted(finished_streams, reverse=True):
|
||||
decode_results.append(
|
||||
(
|
||||
streams[i].id,
|
||||
streams[i].ground_truth.split(),
|
||||
sp.decode(streams[i].decoding_result()).split(),
|
||||
)
|
||||
@ -731,6 +733,7 @@ def decode_dataset(
|
||||
for i in sorted(finished_streams, reverse=True):
|
||||
decode_results.append(
|
||||
(
|
||||
streams[i].id,
|
||||
streams[i].ground_truth.split(),
|
||||
sp.decode(streams[i].decoding_result()).split(),
|
||||
)
|
||||
|
@ -686,6 +686,15 @@ def compute_loss(
|
||||
(feature_lens // params.subsampling_factor).sum().item()
|
||||
)
|
||||
|
||||
# `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa
|
||||
info["utterances"] = feature.size(0)
|
||||
# averaged input duration in frames over utterances
|
||||
info["utt_duration"] = feature_lens.sum().item()
|
||||
# averaged padding proportion over utterances
|
||||
info["utt_pad_proportion"] = (
|
||||
((feature.size(1) - feature_lens) / feature.size(1)).sum().item()
|
||||
)
|
||||
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
info["simple_loss"] = simple_loss.detach().cpu().item()
|
||||
|
@ -81,9 +81,9 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ] && [ ! "$use_extracted_codebook" ==
|
||||
# or
|
||||
# pip install multi_quantization
|
||||
|
||||
has_quantization=$(python3 -c "import importlib; print(importlib.util.find_spec('quantization') is not None)")
|
||||
has_quantization=$(python3 -c "import importlib; print(importlib.util.find_spec('multi_quantization') is not None)")
|
||||
if [ $has_quantization == 'False' ]; then
|
||||
log "Please install quantization before running following stages"
|
||||
log "Please install multi_quantization before running following stages"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
@ -68,6 +68,7 @@ def compute_fbank_gigaspeech_dev_test():
|
||||
storage_path=f"{in_out_dir}/{prefix}_feats_{partition}",
|
||||
num_workers=num_workers,
|
||||
batch_duration=batch_duration,
|
||||
overwrite=True,
|
||||
)
|
||||
cut_set = cut_set.trim_to_supervisions(
|
||||
keep_overlapping=False, min_duration=None
|
||||
|
@ -126,6 +126,7 @@ def compute_fbank_gigaspeech_splits(args):
|
||||
storage_path=f"{output_dir}/{prefix}_feats_XL_{idx}",
|
||||
num_workers=args.num_workers,
|
||||
batch_duration=args.batch_duration,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
logging.info("About to split cuts into smaller chunks.")
|
||||
|
1
egs/librispeech/ASR/lstm_transducer_stateless/__init__.py
Symbolic link
1
egs/librispeech/ASR/lstm_transducer_stateless/__init__.py
Symbolic link
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless2/__init__.py
|
1
egs/librispeech/ASR/lstm_transducer_stateless/asr_datamodule.py
Symbolic link
1
egs/librispeech/ASR/lstm_transducer_stateless/asr_datamodule.py
Symbolic link
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless2/asr_datamodule.py
|
1
egs/librispeech/ASR/lstm_transducer_stateless/beam_search.py
Symbolic link
1
egs/librispeech/ASR/lstm_transducer_stateless/beam_search.py
Symbolic link
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless2/beam_search.py
|
818
egs/librispeech/ASR/lstm_transducer_stateless/decode.py
Executable file
818
egs/librispeech/ASR/lstm_transducer_stateless/decode.py
Executable file
@ -0,0 +1,818 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||
# Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
(1) greedy search
|
||||
./lstm_transducer_stateless/decode.py \
|
||||
--epoch 35 \
|
||||
--avg 15 \
|
||||
--exp-dir ./lstm_transducer_stateless/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search
|
||||
|
||||
(2) beam search (not recommended)
|
||||
./lstm_transducer_stateless/decode.py \
|
||||
--epoch 35 \
|
||||
--avg 15 \
|
||||
--exp-dir ./lstm_transducer_stateless/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(3) modified beam search
|
||||
./lstm_transducer_stateless/decode.py \
|
||||
--epoch 35 \
|
||||
--avg 15 \
|
||||
--exp-dir ./lstm_transducer_stateless/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(4) fast beam search (one best)
|
||||
./lstm_transducer_stateless/decode.py \
|
||||
--epoch 35 \
|
||||
--avg 15 \
|
||||
--exp-dir ./lstm_transducer_stateless/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64
|
||||
|
||||
(5) fast beam search (nbest)
|
||||
./lstm_transducer_stateless/decode.py \
|
||||
--epoch 30 \
|
||||
--avg 15 \
|
||||
--exp-dir ./lstm_transducer_stateless/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search_nbest \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64 \
|
||||
--num-paths 200 \
|
||||
--nbest-scale 0.5
|
||||
|
||||
(6) fast beam search (nbest oracle WER)
|
||||
./lstm_transducer_stateless/decode.py \
|
||||
--epoch 35 \
|
||||
--avg 15 \
|
||||
--exp-dir ./lstm_transducer_stateless/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search_nbest_oracle \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64 \
|
||||
--num-paths 200 \
|
||||
--nbest-scale 0.5
|
||||
|
||||
(7) fast beam search (with LG)
|
||||
./lstm_transducer_stateless/decode.py \
|
||||
--epoch 35 \
|
||||
--avg 15 \
|
||||
--exp-dir ./lstm_transducer_stateless/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search_nbest_LG \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG,
|
||||
fast_beam_search_nbest_oracle,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
|
||||
def get_parser():
    """Build the command-line parser for the decoding script.

    Only decoding-related options are registered directly here; the
    model-related options are appended by :func:`add_model_arguments`.

    Returns:
      The configured :class:`argparse.ArgumentParser`.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=30,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    # Fixed: the original help was missing a space between two sentences
    # and claimed that only --epoch is supported, although main() also
    # handles --iter when this flag is true.
    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load averaged model. "
        "If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch` "
        "(or over the matching iteration checkpoints if --iter is used). "
        "Actually only the two checkpoints at both ends of the range "
        "are loaded for averaging. ",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="lstm_transducer_stateless/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        default="data/lang_bpe_500/bpe.model",
        help="Path to the BPE model",
    )

    parser.add_argument(
        "--lang-dir",
        type=Path,
        default="data/lang_bpe_500",
        help="The lang dir containing word table and LG graph",
    )

    parser.add_argument(
        "--decoding-method",
        type=str,
        default="greedy_search",
        help="""Possible values are:
          - greedy_search
          - beam_search
          - modified_beam_search
          - fast_beam_search
          - fast_beam_search_nbest
          - fast_beam_search_nbest_oracle
          - fast_beam_search_nbest_LG
        If you use fast_beam_search_nbest_LG, you have to specify
        `--lang-dir`, which should contain `LG.pt`.
        """,
    )

    parser.add_argument(
        "--beam-size",
        type=int,
        default=4,
        help="""An integer indicating how many candidates we will keep for each
        frame. Used only when --decoding-method is beam_search or
        modified_beam_search.""",
    )

    parser.add_argument(
        "--beam",
        type=float,
        default=20.0,
        help="""A floating point value to calculate the cutoff score during beam
        search (i.e., `cutoff = max-score - beam`), which is the same as the
        `beam` in Kaldi.
        Used only when --decoding-method is fast_beam_search,
        fast_beam_search_nbest, fast_beam_search_nbest_LG,
        and fast_beam_search_nbest_oracle
        """,
    )

    # Fixed: the help referred to a non-existent `--decoding_method` flag.
    parser.add_argument(
        "--ngram-lm-scale",
        type=float,
        default=0.01,
        help="""
        Used only when --decoding-method is fast_beam_search_nbest_LG.
        It specifies the scale for n-gram LM scores.
        """,
    )

    parser.add_argument(
        "--max-contexts",
        type=int,
        default=8,
        help="""Used only when --decoding-method is
        fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
        and fast_beam_search_nbest_oracle""",
    )

    parser.add_argument(
        "--max-states",
        type=int,
        default=64,
        help="""Used only when --decoding-method is
        fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
        and fast_beam_search_nbest_oracle""",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; "
        "2 means tri-gram",
    )

    # Fixed: the help referred to a non-existent `--decoding_method` flag.
    parser.add_argument(
        "--max-sym-per-frame",
        type=int,
        default=1,
        help="""Maximum number of symbols per frame.
        Used only when --decoding-method is greedy_search""",
    )

    parser.add_argument(
        "--num-paths",
        type=int,
        default=200,
        help="""Number of paths for nbest decoding.
        Used only when the decoding method is fast_beam_search_nbest,
        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
    )

    parser.add_argument(
        "--nbest-scale",
        type=float,
        default=0.5,
        help="""Scale applied to lattice scores when computing nbest paths.
        Used only when the decoding method is fast_beam_search_nbest,
        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
    )

    add_model_arguments(parser)

    return parser
|
||||
|
||||
|
||||
def decode_one_batch(
    params: AttributeDict,
    model: nn.Module,
    sp: spm.SentencePieceProcessor,
    batch: dict,
    word_table: Optional[k2.SymbolTable] = None,
    decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
    """Decode one batch and return the result in a dict. The dict has the
    following format:

        - key: It indicates the setting used for decoding. For example,
               if greedy_search is used, it would be "greedy_search".
               If beam search with a beam size of 7 is used, it would be
               "beam_size_7". For the fast_beam_search family the key
               encodes beam, max_contexts and max_states (plus num_paths,
               nbest_scale and ngram_lm_scale where applicable).
        - value: It contains the decoding result. `len(value)` equals to
                 batch size. `value[i]` is the decoding result for the i-th
                 utterance in the given batch.
    Args:
      params:
        It's the return value of :func:`get_params`.
      model:
        The neural model.
      sp:
        The BPE model.
      batch:
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
        for the format of the `batch`. It is not modified by this function.
      word_table:
        The word symbol table. Required when the decoding method is
        fast_beam_search_nbest_LG.
      decoding_graph:
        The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
        only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
        fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
    Returns:
      Return the decoding result. See above description for the format of
      the returned dict.
    Raises:
      ValueError:
        If `params.decoding_method` is not one of the supported methods.
    """
    device = next(model.parameters()).device
    feature = batch["inputs"]
    assert feature.ndim == 3

    feature = feature.to(device)
    # at entry, feature is (N, T, C)

    supervisions = batch["supervisions"]
    feature_lens = supervisions["num_frames"].to(device)

    # tail padding here to alleviate the tail deletion problem
    num_tail_padded_frames = 35
    feature = torch.nn.functional.pad(
        feature,
        (0, 0, 0, num_tail_padded_frames),
        mode="constant",
        value=LOG_EPS,
    )
    # Out-of-place add on purpose: `.to(device)` returns the very same tensor
    # when it already lives on `device`, so `+=` would silently modify
    # batch["supervisions"]["num_frames"] owned by the caller.
    feature_lens = feature_lens + num_tail_padded_frames

    encoder_out, encoder_out_lens, _ = model.encoder(
        x=feature, x_lens=feature_lens
    )

    method = params.decoding_method
    hyps = []

    if method == "fast_beam_search":
        hyp_tokens = fast_beam_search_one_best(
            model=model,
            decoding_graph=decoding_graph,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
        )
        hyps = [hyp.split() for hyp in sp.decode(hyp_tokens)]
    elif method == "fast_beam_search_nbest_LG":
        hyp_tokens = fast_beam_search_nbest_LG(
            model=model,
            decoding_graph=decoding_graph,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
            num_paths=params.num_paths,
            nbest_scale=params.nbest_scale,
        )
        # With LG the results are word IDs, so they are mapped through the
        # word table instead of the BPE model.
        hyps = [[word_table[i] for i in hyp] for hyp in hyp_tokens]
    elif method == "fast_beam_search_nbest":
        hyp_tokens = fast_beam_search_nbest(
            model=model,
            decoding_graph=decoding_graph,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
            num_paths=params.num_paths,
            nbest_scale=params.nbest_scale,
        )
        hyps = [hyp.split() for hyp in sp.decode(hyp_tokens)]
    elif method == "fast_beam_search_nbest_oracle":
        hyp_tokens = fast_beam_search_nbest_oracle(
            model=model,
            decoding_graph=decoding_graph,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
            num_paths=params.num_paths,
            ref_texts=sp.encode(supervisions["text"]),
            nbest_scale=params.nbest_scale,
        )
        hyps = [hyp.split() for hyp in sp.decode(hyp_tokens)]
    elif method == "greedy_search" and params.max_sym_per_frame == 1:
        hyp_tokens = greedy_search_batch(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
        )
        hyps = [hyp.split() for hyp in sp.decode(hyp_tokens)]
    elif method == "modified_beam_search":
        hyp_tokens = modified_beam_search(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam_size,
        )
        hyps = [hyp.split() for hyp in sp.decode(hyp_tokens)]
    else:
        # Remaining methods are decoded one utterance at a time.
        batch_size = encoder_out.size(0)

        for i in range(batch_size):
            # fmt: off
            encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
            # fmt: on
            if method == "greedy_search":
                hyp = greedy_search(
                    model=model,
                    encoder_out=encoder_out_i,
                    max_sym_per_frame=params.max_sym_per_frame,
                )
            elif method == "beam_search":
                hyp = beam_search(
                    model=model,
                    encoder_out=encoder_out_i,
                    beam=params.beam_size,
                )
            else:
                raise ValueError(
                    f"Unsupported decoding method: {params.decoding_method}"
                )
            hyps.append(sp.decode(hyp).split())

    if method == "greedy_search":
        return {"greedy_search": hyps}
    elif "fast_beam_search" in method:
        key = f"beam_{params.beam}_"
        key += f"max_contexts_{params.max_contexts}_"
        key += f"max_states_{params.max_states}"
        if "nbest" in method:
            key += f"_num_paths_{params.num_paths}_"
            key += f"nbest_scale_{params.nbest_scale}"
            if "LG" in method:
                key += f"_ngram_lm_scale_{params.ngram_lm_scale}"

        return {key: hyps}
    else:
        return {f"beam_size_{params.beam_size}": hyps}
|
||||
|
||||
|
||||
def decode_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
    sp: spm.SentencePieceProcessor,
    word_table: Optional[k2.SymbolTable] = None,
    decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
    """Decode dataset.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to decode.
      params:
        It is returned by :func:`get_params`.
      model:
        The neural model.
      sp:
        The BPE model.
      word_table:
        The word symbol table.
      decoding_graph:
        The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
        only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
        fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
    Returns:
      Return a dict whose keys are the setting names produced by
      :func:`decode_one_batch` (e.g. "greedy_search").
      Its value is a list of tuples. Each tuple contains three elements:
      the cut ID, the reference transcript (as a list of words), and the
      predicted result (as a list of words).
    """
    num_cuts = 0

    # Some dataloaders do not define a length; fall back to a placeholder
    # that is only used in the progress message.
    try:
        num_batches = len(dl)
    except TypeError:
        num_batches = "?"

    log_interval = 50 if params.decoding_method == "greedy_search" else 20

    results = defaultdict(list)
    for batch_idx, batch in enumerate(dl):
        texts = batch["supervisions"]["text"]
        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

        hyps_dict = decode_one_batch(
            params=params,
            model=model,
            sp=sp,
            decoding_graph=decoding_graph,
            word_table=word_table,
            batch=batch,
        )

        for name, hyps in hyps_dict.items():
            assert len(hyps) == len(texts)
            results[name].extend(
                (cut_id, ref_text.split(), hyp_words)
                for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts)
            )

        num_cuts += len(texts)

        if batch_idx % log_interval == 0:
            batch_str = f"{batch_idx}/{num_batches}"

            logging.info(
                f"batch {batch_str}, cuts processed until now is {num_cuts}"
            )
    return results
|
||||
|
||||
|
||||
def save_results(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
    """Write transcripts, error statistics and a WER summary for a test set.

    For every decoding setting in `results_dict` this writes a
    `recogs-*.txt` and an `errs-*.txt` file into `params.res_dir`, then
    writes one `wer-summary-*.txt` file and logs the WERs sorted from best
    to worst.

    Args:
      params:
        Must provide `res_dir` (a path-like supporting `/`) and `suffix`.
      test_set_name:
        Name of the test set; used in filenames and log messages.
      results_dict:
        Maps a decoding-setting name to a list of
        (cut_id, ref_words, hyp_words) tuples, as returned by
        :func:`decode_dataset`.
    """
    if not results_dict:
        # Previously this case crashed with an unbound `key` while building
        # the summary filename below.
        logging.warning(
            f"No decoding results for {test_set_name}; nothing to save"
        )
        return

    test_set_wers = dict()
    for key, results in results_dict.items():
        recog_path = (
            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        results = sorted(results)
        store_transcripts(filename=recog_path, texts=results)
        logging.info(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = (
            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        with open(errs_filename, "w") as f:
            wer = write_error_stats(
                f, f"{test_set_name}-{key}", results, enable_log=True
            )
            test_set_wers[key] = wer

        logging.info("Wrote detailed error stats to {}".format(errs_filename))

    # The summary filename has always embedded the LAST setting name seen in
    # the loop above; keep that so existing output paths do not change.
    last_key = key

    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
    errs_info = (
        params.res_dir
        / f"wer-summary-{test_set_name}-{last_key}-{params.suffix}.txt"
    )
    with open(errs_info, "w") as f:
        print("settings\tWER", file=f)
        for setting, val in test_set_wers:
            print("{}\t{}".format(setting, val), file=f)

    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
    # Only the first (lowest-WER) entry is tagged as the best one.
    note = "\tbest for {}".format(test_set_name)
    for setting, val in test_set_wers:
        s += "{}\t{}{}\n".format(setting, val, note)
        note = ""
    logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
def main():
    """Decode the LibriSpeech test-clean and test-other sets.

    Parses the command line, builds the transducer model, loads (and
    optionally averages) checkpoints from `exp_dir`, prepares the decoding
    graph when a fast_beam_search variant is selected, then decodes both
    test sets and saves the results under `exp_dir/<decoding_method>`.
    """
    parser = get_parser()
    LibriSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    # Command-line values override the defaults from get_params().
    params = get_params()
    params.update(vars(args))

    assert params.decoding_method in (
        "greedy_search",
        "beam_search",
        "fast_beam_search",
        "fast_beam_search_nbest",
        "fast_beam_search_nbest_LG",
        "fast_beam_search_nbest_oracle",
        "modified_beam_search",
    )
    params.res_dir = params.exp_dir / params.decoding_method

    # `suffix` is appended to every output filename so that runs with
    # different settings do not overwrite each other.
    if params.iter > 0:
        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
    else:
        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

    if "fast_beam_search" in params.decoding_method:
        params.suffix += f"-beam-{params.beam}"
        params.suffix += f"-max-contexts-{params.max_contexts}"
        params.suffix += f"-max-states-{params.max_states}"
        if "nbest" in params.decoding_method:
            params.suffix += f"-nbest-scale-{params.nbest_scale}"
            params.suffix += f"-num-paths-{params.num_paths}"
            if "LG" in params.decoding_method:
                params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
    elif "beam_search" in params.decoding_method:
        params.suffix += (
            f"-{params.decoding_method}-beam-size-{params.beam_size}"
        )
    else:
        params.suffix += f"-context-{params.context_size}"
        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

    if params.use_averaged_model:
        params.suffix += "-use-averaged-model"

    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
    logging.info("Decoding started")

    # Use the first GPU when CUDA is available, otherwise the CPU.
    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"Device: {device}")

    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)

    # <blk> and <unk> are defined in local/train_bpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
    params.unk_id = sp.piece_to_id("<unk>")
    params.vocab_size = sp.get_piece_size()

    logging.info(params)

    logging.info("About to create model")
    model = get_transducer_model(params)

    if not params.use_averaged_model:
        # Plain averaging over the state dicts of the selected checkpoints.
        if params.iter > 0:
            # NOTE(review): presumably a negative `iteration` makes
            # find_checkpoints return checkpoints up to that iteration,
            # newest first -- confirm against icefall.checkpoint.
            filenames = find_checkpoints(
                params.exp_dir, iteration=-params.iter
            )[: params.avg]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
        elif params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            start = params.epoch - params.avg + 1
            filenames = []
            for i in range(start, params.epoch + 1):
                # Epochs below 1 are skipped (epoch numbering starts at 1).
                if i >= 1:
                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
    else:
        # Averaging based on the running averaged model: only the two
        # checkpoints at the ends of the range are passed to the helper.
        if params.iter > 0:
            # One extra checkpoint is needed because the start of the range
            # is excluded.
            filenames = find_checkpoints(
                params.exp_dir, iteration=-params.iter
            )[: params.avg + 1]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            # epoch-{start}.pt must exist, so the range cannot begin before
            # epoch 1.
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )

    model.to(device)
    model.eval()

    if "fast_beam_search" in params.decoding_method:
        if params.decoding_method == "fast_beam_search_nbest_LG":
            lexicon = Lexicon(params.lang_dir)
            word_table = lexicon.word_table
            lg_filename = params.lang_dir / "LG.pt"
            logging.info(f"Loading {lg_filename}")
            decoding_graph = k2.Fsa.from_dict(
                torch.load(lg_filename, map_location=device)
            )
            # Scale the graph scores by the n-gram LM scale.
            decoding_graph.scores *= params.ngram_lm_scale
        else:
            word_table = None
            decoding_graph = k2.trivial_graph(
                params.vocab_size - 1, device=device
            )
    else:
        decoding_graph = None
        word_table = None

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    # we need cut ids to display recognition results.
    args.return_cuts = True
    librispeech = LibriSpeechAsrDataModule(args)

    test_clean_cuts = librispeech.test_clean_cuts()
    test_other_cuts = librispeech.test_other_cuts()

    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
    test_other_dl = librispeech.test_dataloaders(test_other_cuts)

    test_sets = ["test-clean", "test-other"]
    test_dl = [test_clean_dl, test_other_dl]

    # NOTE: the loop variable `test_dl` rebinds the list above; this works
    # because zip() has already captured the list before the first iteration.
    for test_set, test_dl in zip(test_sets, test_dl):
        results_dict = decode_dataset(
            dl=test_dl,
            params=params,
            model=model,
            sp=sp,
            word_table=word_table,
            decoding_graph=decoding_graph,
        )

        save_results(
            params=params,
            test_set_name=test_set,
            results_dict=results_dict,
        )

    logging.info("Done!")


if __name__ == "__main__":
    main()
|
1
egs/librispeech/ASR/lstm_transducer_stateless/decoder.py
Symbolic link
1
egs/librispeech/ASR/lstm_transducer_stateless/decoder.py
Symbolic link
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless2/decoder.py
|
@ -0,0 +1 @@
|
||||
../transducer_stateless/encoder_interface.py
|
388
egs/librispeech/ASR/lstm_transducer_stateless/export.py
Executable file
388
egs/librispeech/ASR/lstm_transducer_stateless/export.py
Executable file
@ -0,0 +1,388 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This script converts several saved checkpoints
|
||||
# to a single one using model averaging.
|
||||
"""
|
||||
|
||||
Usage:
|
||||
|
||||
(1) Export to torchscript model using torch.jit.trace()
|
||||
|
||||
./lstm_transducer_stateless/export.py \
|
||||
--exp-dir ./lstm_transducer_stateless/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 35 \
|
||||
--avg 10 \
|
||||
--jit-trace 1
|
||||
|
||||
It will generate 3 files: `encoder_jit_trace.pt`,
|
||||
`decoder_jit_trace.pt`, and `joiner_jit_trace.pt`.
|
||||
|
||||
(2) Export `model.state_dict()`
|
||||
|
||||
./lstm_transducer_stateless/export.py \
|
||||
--exp-dir ./lstm_transducer_stateless/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 35 \
|
||||
--avg 10
|
||||
|
||||
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
|
||||
load it by `icefall.checkpoint.load_checkpoint()`.
|
||||
|
||||
To use the generated file with `lstm_transducer_stateless/decode.py`,
|
||||
you can do:
|
||||
|
||||
cd /path/to/exp_dir
|
||||
ln -s pretrained.pt epoch-9999.pt
|
||||
|
||||
cd /path/to/egs/librispeech/ASR
|
||||
./lstm_transducer_stateless/decode.py \
|
||||
--exp-dir ./lstm_transducer_stateless/exp \
|
||||
--epoch 9999 \
|
||||
--avg 1 \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search \
|
||||
--bpe-model data/lang_bpe_500/bpe.model
|
||||
|
||||
Check ./pretrained.py for its usage.
|
||||
|
||||
Note: If you don't want to train a model from scratch, we have
|
||||
provided one for you. You can get it at
|
||||
|
||||
https://huggingface.co/Zengwei/icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18
|
||||
|
||||
with the following commands:
|
||||
|
||||
sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18
|
||||
# You will find the pre-trained model in icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18/exp
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
def get_parser():
    """Build the command-line parser for the export script.

    Only export-related options are registered directly here; the
    model-related options are appended by :func:`add_model_arguments`.

    Returns:
      The configured :class:`argparse.ArgumentParser`.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # Fixed: the help used to say "Epoch counts from 0", but the averaging
    # code in this script skips any epoch below 1.
    parser.add_argument(
        "--epoch",
        type=int,
        default=28,
        help="""It specifies the checkpoint to use for averaging.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    # Fixed: the original help was missing a space between two sentences,
    # spoke of decoding in an export script, and claimed that only --epoch
    # is supported although --iter is handled as well.
    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load averaged model. "
        "If True, it would export the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch` "
        "(or over the matching iteration checkpoints if --iter is used). "
        "Actually only the two checkpoints at both ends of the range "
        "are loaded for averaging. ",
    )

    # NOTE(review): this default points at another recipe's experiment
    # directory (the usage examples always pass
    # ./lstm_transducer_stateless/exp). It is kept unchanged so existing
    # invocations that rely on the default keep working -- confirm whether
    # it should be "lstm_transducer_stateless/exp".
    parser.add_argument(
        "--exp-dir",
        type=str,
        default="pruned_transducer_stateless3/exp",
        help="""It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
        """,
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        default="data/lang_bpe_500/bpe.model",
        help="Path to the BPE model",
    )

    parser.add_argument(
        "--jit-trace",
        type=str2bool,
        default=False,
        help="""True to save a model after applying torch.jit.trace.
        It will generate 3 files:
         - encoder_jit_trace.pt
         - decoder_jit_trace.pt
         - joiner_jit_trace.pt

        Check ./jit_pretrained.py for how to use them.
        """,
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; "
        "2 means tri-gram",
    )

    add_model_arguments(parser)

    return parser
|
||||
|
||||
|
||||
def export_encoder_model_jit_trace(
    encoder_model: nn.Module,
    encoder_filename: str,
) -> None:
    """Trace the encoder with torch.jit.trace() and save it to disk.

    The encoder is traced on an all-zero float32 input of shape
    (1, 100, 80) with a length tensor of [100] and the states returned by
    `encoder_model.get_init_states()`.

    Note: no `warmup` argument is passed while tracing, so the encoder's
    own default is what gets baked in (the original note says it is fixed
    to 1 -- TODO confirm against the encoder definition).

    Args:
      encoder_model:
        The input encoder model. It must provide `get_init_states()`.
      encoder_filename:
        The filename to save the exported model.
    """
    dummy_features = torch.zeros(1, 100, 80, dtype=torch.float32)
    dummy_lengths = torch.tensor([100], dtype=torch.int64)
    init_states = encoder_model.get_init_states()
    example_inputs = (dummy_features, dummy_lengths, init_states)

    torch.jit.trace(encoder_model, example_inputs).save(encoder_filename)
    logging.info(f"Saved to {encoder_filename}")
|
||||
|
||||
|
||||
def export_decoder_model_jit_trace(
    decoder_model: nn.Module,
    decoder_filename: str,
) -> None:
    """Trace the decoder with torch.jit.trace() and save it to disk.

    The decoder is traced on an all-zero int64 input of shape
    (10, decoder_model.context_size).

    Note: The argument need_pad is fixed to False.

    Args:
      decoder_model:
        The input decoder model. It must provide a `context_size`
        attribute.
      decoder_filename:
        The filename to save the exported model.
    """
    dummy_tokens = torch.zeros(
        10, decoder_model.context_size, dtype=torch.int64
    )
    no_padding = torch.tensor([False])
    example_inputs = (dummy_tokens, no_padding)

    torch.jit.trace(decoder_model, example_inputs).save(decoder_filename)
    logging.info(f"Saved to {decoder_filename}")
|
||||
|
||||
|
||||
def export_joiner_model_jit_trace(
    joiner_model: nn.Module,
    joiner_filename: str,
) -> None:
    """Trace the joiner with torch.jit.trace() and save it to disk.

    The joiner is traced on random float32 inputs of shape (1, D_enc) and
    (1, D_dec), where the dims are read from the second dimension of the
    `encoder_proj` / `decoder_proj` weights.

    Note: no `project_input` argument is passed while tracing, so the
    joiner's own default is what gets baked in (the original note says it
    is fixed to True, i.e. the exported joiner applies the projections
    itself and callers should pass unprojected encoder_out/decoder_out --
    TODO confirm against the joiner definition).

    Args:
      joiner_model:
        The input joiner model. It must provide `encoder_proj` and
        `decoder_proj` submodules with a `weight` attribute.
      joiner_filename:
        The filename to save the exported model.
    """
    enc_dim = joiner_model.encoder_proj.weight.shape[1]
    dec_dim = joiner_model.decoder_proj.weight.shape[1]
    # The encoder input is drawn first, then the decoder input.
    dummy_encoder_out = torch.rand(1, enc_dim, dtype=torch.float32)
    dummy_decoder_out = torch.rand(1, dec_dim, dtype=torch.float32)
    example_inputs = (dummy_encoder_out, dummy_decoder_out)

    torch.jit.trace(joiner_model, example_inputs).save(joiner_filename)
    logging.info(f"Saved to {joiner_filename}")
|
||||
|
||||
|
||||
@torch.no_grad()
def main():
    """Build the model, load (or average) checkpoints, and export it.

    Checkpoints are selected in one of four ways, depending on
    ``--use-averaged-model`` and ``--iter``:

      - plain averaging over the ``--avg`` newest iteration checkpoints;
      - plain averaging over (or, when ``--avg`` is 1, direct loading of)
        epoch checkpoints ending at ``--epoch``;
      - averaged-model averaging between two iteration checkpoints;
      - averaged-model averaging between two epoch checkpoints.

    With ``--jit-trace`` the encoder, decoder and joiner are traced and saved
    separately in ``exp_dir``; otherwise the state dict is saved to
    ``exp_dir / "pretrained.pt"``.
    """
    args = get_parser().parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    else:
        device = torch.device("cpu")
    logging.info(f"device: {device}")

    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)

    # <blk> is defined in local/train_bpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
    params.vocab_size = sp.get_piece_size()

    logging.info(params)

    logging.info("About to create model")
    model = get_transducer_model(params)

    use_averaged = params.use_averaged_model
    by_iteration = params.iter > 0

    if by_iteration:
        # Averaged-model mode needs one extra checkpoint, which serves as the
        # excluded start of the range.
        num_needed = params.avg + 1 if use_averaged else params.avg
        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
            :num_needed
        ]
        if len(filenames) == 0:
            raise ValueError(
                f"No checkpoints found for"
                f" --iter {params.iter}, --avg {params.avg}"
            )
        if len(filenames) < num_needed:
            raise ValueError(
                f"Not enough checkpoints ({len(filenames)}) found for"
                f" --iter {params.iter}, --avg {params.avg}"
            )

        if use_averaged:
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )
        else:
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
    elif use_averaged:
        assert params.avg > 0, params.avg
        start = params.epoch - params.avg
        assert start >= 1, start
        filename_start = f"{params.exp_dir}/epoch-{start}.pt"
        filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
        logging.info(
            f"Calculating the averaged model over epoch range from "
            f"{start} (excluded) to {params.epoch}"
        )
        model.to(device)
        model.load_state_dict(
            average_checkpoints_with_averaged_model(
                filename_start=filename_start,
                filename_end=filename_end,
                device=device,
            )
        )
    elif params.avg == 1:
        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    else:
        start = params.epoch - params.avg + 1
        # Epochs are numbered from 1, so anything below that is dropped.
        filenames = [
            f"{params.exp_dir}/epoch-{i}.pt"
            for i in range(start, params.epoch + 1)
            if i >= 1
        ]
        logging.info(f"averaging {filenames}")
        model.to(device)
        model.load_state_dict(average_checkpoints(filenames, device=device))

    model.to("cpu")
    model.eval()

    if params.jit_trace is not True:
        logging.info("Not using torchscript")
        # Save it using a format so that it can be loaded
        # by :func:`load_checkpoint`
        filename = params.exp_dir / "pretrained.pt"
        torch.save({"model": model.state_dict()}, str(filename))
        logging.info(f"Saved to {filename}")
        return

    convert_scaled_to_non_scaled(model, inplace=True)
    logging.info("Using torch.jit.trace()")

    exports = (
        (export_encoder_model_jit_trace, model.encoder, "encoder_jit_trace.pt"),
        (export_decoder_model_jit_trace, model.decoder, "decoder_jit_trace.pt"),
        (export_joiner_model_jit_trace, model.joiner, "joiner_jit_trace.pt"),
    )
    for export_fn, submodule, basename in exports:
        export_fn(submodule, params.exp_dir / basename)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Each record carries a timestamp, the level, and the emitting
    # file/line ahead of the message.
    log_format = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )
    logging.basicConfig(level=logging.INFO, format=log_format)
    main()
|
322
egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py
Executable file
322
egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py
Executable file
@ -0,0 +1,322 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads torchscript models, either exported by `torch.jit.trace()`
|
||||
or by `torch.jit.script()`, and uses them to decode waves.
|
||||
You can use the following command to get the exported models:
|
||||
|
||||
./lstm_transducer_stateless/export.py \
|
||||
--exp-dir ./lstm_transducer_stateless/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 20 \
|
||||
--avg 10 \
|
||||
--jit-trace 1
|
||||
|
||||
Usage of this script:
|
||||
|
||||
./lstm_transducer_stateless/jit_pretrained.py \
|
||||
--encoder-model-filename ./lstm_transducer_stateless/exp/encoder_jit_trace.pt \
|
||||
--decoder-model-filename ./lstm_transducer_stateless/exp/decoder_jit_trace.pt \
|
||||
--joiner-model-filename ./lstm_transducer_stateless/exp/joiner_jit_trace.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import kaldifeat
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
|
||||
def get_parser():
    """Build the command-line parser for this script.

    Returns:
      An ``argparse.ArgumentParser`` with the three torchscript model paths,
      the BPE model path (all required), the positional list of sound files,
      and the optional ``--sample-rate`` / ``--context-size`` settings.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--encoder-model-filename",
        type=str,
        required=True,
        help="Path to the encoder torchscript model. ",
    )

    parser.add_argument(
        "--decoder-model-filename",
        type=str,
        required=True,
        help="Path to the decoder torchscript model. ",
    )

    parser.add_argument(
        "--joiner-model-filename",
        type=str,
        required=True,
        help="Path to the joiner torchscript model. ",
    )

    # main() always loads the BPE model, so a missing path is reported here
    # by argparse instead of failing later inside sentencepiece.
    parser.add_argument(
        "--bpe-model",
        type=str,
        required=True,
        help="""Path to bpe.model.""",
    )

    parser.add_argument(
        "sound_files",
        type=str,
        nargs="+",
        help="The input sound file(s) to transcribe. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported. "
        "The sample rate has to match --sample-rate.",
    )

    parser.add_argument(
        "--sample-rate",
        type=int,
        default=16000,
        help="The sample rate of the input sound file",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="Context size of the decoder model",
    )

    return parser
|
||||
|
||||
|
||||
def read_sound_files(
    filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
    """Load every file with torchaudio and keep its first channel.

    Args:
      filenames:
        Paths of the sound files to load.
      expected_sample_rate:
        The sample rate each file must have; a mismatch trips an assertion.
    Returns:
      One tensor per input file, in the same order, holding the first
      channel returned by ``torchaudio.load``.
    """

    def _load_first_channel(path: str) -> torch.Tensor:
        wave, sample_rate = torchaudio.load(path)
        assert sample_rate == expected_sample_rate, (
            f"expected sample rate: {expected_sample_rate}. "
            f"Given: {sample_rate}"
        )
        # We use only the first channel
        return wave[0]

    return [_load_first_channel(path) for path in filenames]
|
||||
|
||||
|
||||
def greedy_search(
    decoder: torch.jit.ScriptModule,
    joiner: torch.jit.ScriptModule,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    context_size: int,
) -> List[List[int]]:
    """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.

    The utterances are packed (longest first) so that at every frame only the
    utterances that still have frames left are passed to the joiner.

    Args:
      decoder:
        The decoder model. It is called as ``decoder(tokens, need_pad=...)``
        and its output is squeezed along dim 1.
      joiner:
        The joiner model. It is called as ``joiner(encoder_frame, decoder_out)``
        and must return a 2-D tensor of per-token scores.
      encoder_out:
        A 3-D tensor of shape (N, T, C)
      encoder_out_lens:
        A 1-D tensor of shape (N,). Every entry must be positive.
      context_size:
        The context size of the decoder model.
    Returns:
      Return the decoded results for each utterance, in the order of the
      input batch. Blank (id 0) is never part of a result.
    """
    assert encoder_out.ndim == 3
    assert encoder_out.size(0) >= 1, encoder_out.size(0)
    # Validate the lengths before packing: pack_padded_sequence() itself
    # rejects non-positive lengths, so checking afterwards can never fire.
    assert torch.all(encoder_out_lens > 0), encoder_out_lens

    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
        input=encoder_out,
        lengths=encoder_out_lens.cpu(),
        batch_first=True,
        enforce_sorted=False,
    )

    device = encoder_out.device
    blank_id = 0  # hard-code to 0

    batch_size_list = packed_encoder_out.batch_sizes.tolist()
    N = encoder_out.size(0)
    assert N == batch_size_list[0], (N, batch_size_list)

    # hyps follows the packed (length-sorted) order, not the input order.
    hyps = [[blank_id] * context_size for _ in range(N)]

    # The same flag is passed on every decoder call, so build it once.
    need_pad = torch.tensor([False])

    decoder_input = torch.tensor(
        hyps,
        device=device,
        dtype=torch.int64,
    )  # (N, context_size)

    decoder_out = decoder(decoder_input, need_pad=need_pad).squeeze(1)

    offset = 0
    for batch_size in batch_size_list:
        start = offset
        end = offset + batch_size
        # current_encoder_out's shape: (batch_size, encoder_out_dim)
        current_encoder_out = packed_encoder_out.data[start:end]
        offset = end

        # Utterances that already ended sit at the tail of the sorted batch.
        decoder_out = decoder_out[:batch_size]

        logits = joiner(current_encoder_out, decoder_out)
        # logits'shape (batch_size, vocab_size)

        assert logits.ndim == 2, logits.shape
        y = logits.argmax(dim=1).tolist()
        emitted = False
        for i, v in enumerate(y):
            if v != blank_id:
                hyps[i].append(v)
                emitted = True
        if emitted:
            # update decoder output
            decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
            decoder_input = torch.tensor(
                decoder_input,
                device=device,
                dtype=torch.int64,
            )
            decoder_out = decoder(decoder_input, need_pad=need_pad).squeeze(1)

    # Drop the leading blank context, then map back to the input order.
    sorted_ans = [h[context_size:] for h in hyps]
    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
    return [sorted_ans[idx] for idx in unsorted_indices]
|
||||
|
||||
|
||||
@torch.no_grad()
def main():
    """Transcribe the given sound files with the torchscript models.

    Loads the encoder, decoder and joiner, computes 80-bin fbank features
    for all files, runs the encoder once on the padded batch, decodes with
    greedy search and logs one transcript per input file.
    """
    args = get_parser().parse_args()
    logging.info(vars(args))

    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    else:
        device = torch.device("cpu")
    logging.info(f"device: {device}")

    encoder = torch.jit.load(args.encoder_model_filename)
    decoder = torch.jit.load(args.decoder_model_filename)
    joiner = torch.jit.load(args.joiner_model_filename)

    modules = (encoder, decoder, joiner)
    for module in modules:
        module.eval()
    for module in modules:
        module.to(device)

    sp = spm.SentencePieceProcessor()
    sp.load(args.bpe_model)

    logging.info("Constructing Fbank computer")
    opts = kaldifeat.FbankOptions()
    opts.device = device
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = args.sample_rate
    opts.mel_opts.num_bins = 80

    fbank = kaldifeat.Fbank(opts)

    logging.info(f"Reading sound files: {args.sound_files}")
    waves = [
        wave.to(device)
        for wave in read_sound_files(
            filenames=args.sound_files,
            expected_sample_rate=args.sample_rate,
        )
    ]

    logging.info("Decoding started")
    features = fbank(waves)
    # Record the true frame counts before the batch is padded.
    num_frames = [feature.size(0) for feature in features]

    features = pad_sequence(
        features,
        batch_first=True,
        padding_value=math.log(1e-10),
    )
    feature_lengths = torch.tensor(num_frames, device=device)

    states = encoder.get_init_states(batch_size=features.size(0), device=device)

    encoder_out, encoder_out_lens, _ = encoder(
        x=features,
        x_lens=feature_lengths,
        states=states,
    )

    hyps = greedy_search(
        decoder=decoder,
        joiner=joiner,
        encoder_out=encoder_out,
        encoder_out_lens=encoder_out_lens,
        context_size=args.context_size,
    )

    report = "\n" + "".join(
        f"{filename}:\n{sp.decode(hyp)}\n\n"
        for filename, hyp in zip(args.sound_files, hyps)
    )
    logging.info(report)

    logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Each record carries a timestamp, the level, and the emitting
    # file/line ahead of the message.
    log_format = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )
    logging.basicConfig(level=logging.INFO, format=log_format)
    main()
|
1
egs/librispeech/ASR/lstm_transducer_stateless/joiner.py
Symbolic link
1
egs/librispeech/ASR/lstm_transducer_stateless/joiner.py
Symbolic link
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless2/joiner.py
|
842
egs/librispeech/ASR/lstm_transducer_stateless/lstm.py
Normal file
842
egs/librispeech/ASR/lstm_transducer_stateless/lstm.py
Normal file
@ -0,0 +1,842 @@
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import math
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from encoder_interface import EncoderInterface
|
||||
from scaling import (
|
||||
ActivationBalancer,
|
||||
BasicNorm,
|
||||
DoubleSwish,
|
||||
ScaledConv2d,
|
||||
ScaledLinear,
|
||||
ScaledLSTM,
|
||||
)
|
||||
from torch import nn
|
||||
|
||||
# Natural log of 1e-10 (about -23.03).
# NOTE(review): presumably the fill value for padded log-domain feature
# frames; confirm at the use sites.
LOG_EPSILON = math.log(1e-10)
|
||||
|
||||
|
||||
def unstack_states(
    states: Tuple[torch.Tensor, torch.Tensor]
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
    """
    Split batched lstm states into one state per utterance.

    Both tensors are split along dim 1 (the batch dimension), and each piece
    keeps a batch dimension of size 1.

    Args:
      states:
        A tuple of 2 elements.
        ``states[0]`` is the lstm hidden states of a batch of utterances.
        ``states[1]`` is the lstm cell states of a batch of utterances.

    Returns:
      A list of states.
      ``states[i]`` is a tuple of 2 elements for the i-th utterance.
      ``states[i][0]`` is the lstm hidden states of the i-th utterance.
      ``states[i][1]`` is the lstm cell states of the i-th utterance.
    """
    batched_hidden, batched_cell = states

    per_utt_states = []
    for hidden, cell in zip(
        batched_hidden.unbind(dim=1), batched_cell.unbind(dim=1)
    ):
        # unbind() dropped the batch dimension; put it back with size 1.
        per_utt_states.append((hidden.unsqueeze(1), cell.unsqueeze(1)))
    return per_utt_states
|
||||
|
||||
|
||||
def stack_states(
    states_list: List[Tuple[torch.Tensor, torch.Tensor]]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Merge per-utterance lstm states into a single batched state, so that it
    can be fed to the lstm when those utterances are formed into a batch.

    The hidden states and the cell states are each concatenated along dim 1
    (the batch dimension).

    Args:
      states_list:
        Each element corresponds to the lstm state of a single utterance.
        ``states_list[i]`` is a tuple of 2 elements for the i-th utterance.
        ``states_list[i][0]`` is the lstm hidden states of the i-th utterance.
        ``states_list[i][1]`` is the lstm cell states of the i-th utterance.

    Returns:
      A new state corresponding to a batch of utterances.
      It is a tuple of 2 elements.
      ``states[0]`` is the lstm hidden states of the batch.
      ``states[1]`` is the lstm cell states of the batch.
    """
    hidden_parts = []
    cell_parts = []
    for utt_state in states_list:
        hidden_parts.append(utt_state[0])
        cell_parts.append(utt_state[1])

    return (
        torch.cat(hidden_parts, dim=1),
        torch.cat(cell_parts, dim=1),
    )
|
||||
|
||||
|
||||
class RNN(EncoderInterface):
    """
    Encoder made of a convolutional subsampling front end followed by a
    stack of lstm-based encoder layers.

    Args:
      num_features (int):
        Number of input features.
      subsampling_factor (int):
        Subsampling factor of encoder (convolution layers before lstm layers) (default=4). # noqa
        Only 4 is supported.
      d_model (int):
        Output dimension (default=512).
      dim_feedforward (int):
        Feedforward dimension (default=2048).
      rnn_hidden_size (int):
        Hidden dimension for lstm layers (default=1024).
      num_encoder_layers (int):
        Number of encoder layers (default=12).
      dropout (float):
        Dropout rate (default=0.1).
      layer_dropout (float):
        Dropout value for model-level warmup (default=0.075).
      aux_layer_period (int):
        Period of auxiliary layers used for random combiner during training.
        If set to 0, will not use the random combiner (Default).
        You can set a positive integer to use the random combiner, e.g., 3.
    """

    def __init__(
        self,
        num_features: int,
        subsampling_factor: int = 4,
        d_model: int = 512,
        dim_feedforward: int = 2048,
        rnn_hidden_size: int = 1024,
        num_encoder_layers: int = 12,
        dropout: float = 0.1,
        layer_dropout: float = 0.075,
        aux_layer_period: int = 0,
    ) -> None:
        super().__init__()

        self.num_features = num_features
        self.subsampling_factor = subsampling_factor
        if subsampling_factor != 4:
            raise NotImplementedError("Support only 'subsampling_factor=4'.")

        # self.encoder_embed converts the input of shape (N, T, num_features)
        # to the shape (N, T//subsampling_factor, d_model).
        # That is, it does two things simultaneously:
        #   (1) subsampling: T -> T//subsampling_factor
        #   (2) embedding: num_features -> d_model
        self.encoder_embed = Conv2dSubsampling(num_features, d_model)

        self.num_encoder_layers = num_encoder_layers
        self.d_model = d_model
        self.rnn_hidden_size = rnn_hidden_size

        layer_template = RNNEncoderLayer(
            d_model=d_model,
            dim_feedforward=dim_feedforward,
            rnn_hidden_size=rnn_hidden_size,
            dropout=dropout,
            layer_dropout=layer_dropout,
        )

        # Auxiliary layers run from one third of the stack up to (but not
        # including) the last layer, spaced aux_layer_period apart.
        aux_layers = None
        if aux_layer_period > 0:
            aux_layers = list(
                range(
                    num_encoder_layers // 3,
                    num_encoder_layers - 1,
                    aux_layer_period,
                )
            )

        self.encoder = RNNEncoder(
            layer_template,
            num_encoder_layers,
            aux_layers=aux_layers,
        )

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        warmup: float = 1.0,
    ) -> Tuple[torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Args:
          x:
            The input tensor. Its shape is (N, T, C), where N is the batch size,
            T is the sequence length, C is the feature dimension.
          x_lens:
            A tensor of shape (N,), containing the number of frames in `x`
            before padding.
          states:
            A tuple of 2 tensors (optional). It is for streaming inference.
            states[0] is the hidden states of all layers,
              with shape of (num_layers, N, d_model);
            states[1] is the cell states of all layers,
              with shape of (num_layers, N, rnn_hidden_size).
          warmup:
            A floating point value that gradually increases from 0 throughout
            training; when it is >= 1.0 we are "fully warmed up".  It is used
            to turn modules on sequentially.

        Returns:
          A tuple of 3 elements:
            - embeddings: its shape is (N, T', d_model), where T' is the output
                sequence lengths.
            - lengths: a tensor of shape (batch_size,) containing the number of
                frames in `embeddings` before padding.
            - updated states, with the same shapes as the input states. When
                `states` is None, a pair of empty tensors is returned instead.
        """
        x = self.encoder_embed(x)
        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

        # Equivalent to ((x_lens - 3) // 2 - 1) // 2. Shifts are used because
        # `//` on tensors issues a warning, and rounding_mode in torch.div()
        # is available only in torch >= 1.8.0.
        lengths = (((x_lens - 3) >> 1) - 1) >> 1
        if not torch.jit.is_tracing():
            assert x.size(0) == lengths.max().item()

        if states is not None:
            assert not self.training
            assert len(states) == 2
            if not torch.jit.is_tracing():
                # for hidden state
                assert states[0].shape == (
                    self.num_encoder_layers,
                    x.size(1),
                    self.d_model,
                )
                # for cell state
                assert states[1].shape == (
                    self.num_encoder_layers,
                    x.size(1),
                    self.rnn_hidden_size,
                )
            x, new_states = self.encoder(x, states)
        else:
            x = self.encoder(x, warmup=warmup)[0]
            # torch.jit.trace requires returned types to be the same as annotated # noqa
            new_states = (torch.empty(0), torch.empty(0))

        x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
        return x, lengths, new_states

    @torch.jit.export
    def get_init_states(
        self, batch_size: int = 1, device: torch.device = torch.device("cpu")
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return all-zero initial states for a batch.

        Args:
          batch_size:
            Number of utterances in the batch.
          device:
            The device on which the state tensors are created.

        Returns:
          A tuple (hidden, cell) with shapes
          (num_encoder_layers, batch_size, d_model) and
          (num_encoder_layers, batch_size, rnn_hidden_size).
        """
        # rnn hidden states
        init_hidden = torch.zeros(
            (self.num_encoder_layers, batch_size, self.d_model), device=device
        )
        # rnn cell states
        init_cell = torch.zeros(
            (self.num_encoder_layers, batch_size, self.rnn_hidden_size),
            device=device,
        )
        return (init_hidden, init_cell)
|
||||
|
||||
|
||||
class RNNEncoderLayer(nn.Module):
    """
    One encoder layer: an lstm module followed by a feedforward module, each
    added residually, then a balancer and a final norm.

    Args:
      d_model:
        The number of expected features in the input (required).
      dim_feedforward:
        The dimension of feedforward network model (default=2048).
      rnn_hidden_size:
        The hidden dimension of rnn layer. It must be >= d_model.
      dropout:
        The dropout value (default=0.1).
      layer_dropout:
        The dropout value for model-level warmup (default=0.075).
    """

    def __init__(
        self,
        d_model: int,
        dim_feedforward: int,
        rnn_hidden_size: int,
        dropout: float = 0.1,
        layer_dropout: float = 0.075,
    ) -> None:
        super().__init__()
        self.layer_dropout = layer_dropout
        self.d_model = d_model
        self.rnn_hidden_size = rnn_hidden_size

        assert rnn_hidden_size >= d_model, (rnn_hidden_size, d_model)

        # A projection back to d_model is only requested when the lstm is
        # wider than the model dimension.
        if rnn_hidden_size > d_model:
            proj_size = d_model
        else:
            proj_size = 0

        self.lstm = ScaledLSTM(
            input_size=d_model,
            hidden_size=rnn_hidden_size,
            proj_size=proj_size,
            num_layers=1,
            dropout=0.0,
        )
        self.feed_forward = nn.Sequential(
            ScaledLinear(d_model, dim_feedforward),
            ActivationBalancer(channel_dim=-1),
            DoubleSwish(),
            nn.Dropout(dropout),
            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
        )
        self.norm_final = BasicNorm(d_model)

        # try to ensure the output is close to zero-mean (or at least, zero-median).  # noqa
        self.balancer = ActivationBalancer(
            channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
        )
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        src: torch.Tensor,
        states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        warmup: float = 1.0,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Pass the input through the encoder layer.

        Args:
          src:
            The sequence to the encoder layer (required).
            Its shape is (S, N, E), where S is the sequence length,
            N is the batch size, and E is the feature number.
          states:
            A tuple of 2 tensors (optional). It is for streaming inference.
            states[0] is the hidden states of this layer,
              with shape of (1, N, d_model);
            states[1] is the cell states of this layer,
              with shape of (1, N, rnn_hidden_size).
          warmup:
            It controls selective bypass of layers; if < 1.0, we will
            bypass layers more frequently.

        Returns:
          A tuple (output, new_states). ``new_states`` is the lstm state
          returned for the given ``states``, or a pair of empty tensors when
          ``states`` is None.
        """
        src_orig = src

        # alpha = 1.0 means fully use this encoder layer, 0.0 would mean
        # completely bypass it.
        alpha = 1.0
        if self.training:
            if torch.rand(()).item() <= (1.0 - self.layer_dropout):
                alpha = min(0.1 + warmup, 1.0)
            else:
                alpha = 0.1

        # lstm module
        if states is not None:
            assert not self.training
            assert len(states) == 2
            if not torch.jit.is_tracing():
                # for hidden state
                assert states[0].shape == (1, src.size(1), self.d_model)
                # for cell state
                assert states[1].shape == (1, src.size(1), self.rnn_hidden_size)
            src_lstm, new_states = self.lstm(src, states)
        else:
            src_lstm = self.lstm(src)[0]
            # torch.jit.trace requires returned types be the same as annotated
            new_states = (torch.empty(0), torch.empty(0))
        src = src + self.dropout(src_lstm)

        # feed forward module
        src = src + self.dropout(self.feed_forward(src))

        src = self.norm_final(self.balancer(src))

        # Blend with the layer input when the layer is partially bypassed.
        if alpha != 1.0:
            src = alpha * src + (1 - alpha) * src_orig

        return src, new_states
|
||||
|
||||
|
||||
class RNNEncoder(nn.Module):
    """
    RNNEncoder is a stack of N encoder layers.

    Args:
      encoder_layer:
        An instance of the RNNEncoderLayer() class (required). It is
        deep-copied ``num_layers`` times.
      num_layers:
        The number of sub-encoder-layers in the encoder (required).
      aux_layers:
        Optional list of distinct layer indexes whose outputs are combined
        with the last layer's output by a RandomCombine module. It must not
        contain ``num_layers - 1``. If None, no combiner is used.
    """

    def __init__(
        self,
        encoder_layer: nn.Module,
        num_layers: int,
        aux_layers: Optional[List[int]] = None,
    ) -> None:
        super().__init__()
        self.layers = nn.ModuleList(
            [copy.deepcopy(encoder_layer) for _ in range(num_layers)]
        )
        self.num_layers = num_layers
        self.d_model = encoder_layer.d_model
        self.rnn_hidden_size = encoder_layer.rnn_hidden_size

        self.aux_layers: List[int] = []
        self.combiner: Optional[nn.Module] = None
        if aux_layers is not None:
            last_layer = num_layers - 1
            assert len(set(aux_layers)) == len(aux_layers)
            assert last_layer not in aux_layers
            # The last layer always takes part in the combination.
            self.aux_layers = aux_layers + [last_layer]
            self.combiner = RandomCombine(
                num_inputs=len(self.aux_layers),
                final_weight=0.5,
                pure_prob=0.333,
                stddev=2.0,
            )

    def forward(
        self,
        src: torch.Tensor,
        states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        warmup: float = 1.0,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Pass the input through the encoder layer in turn.

        Args:
          src:
            The sequence to the encoder layer (required).
            Its shape is (S, N, E), where S is the sequence length,
            N is the batch size, and E is the feature number.
          states:
            A tuple of 2 tensors (optional). It is for streaming inference.
            states[0] is the hidden states of all layers,
              with shape of (num_layers, N, d_model);
            states[1] is the cell states of all layers,
              with shape of (num_layers, N, rnn_hidden_size).
          warmup:
            It controls selective bypass of layers; if < 1.0, we will
            bypass layers more frequently.

        Returns:
          A tuple (output, new_states). ``new_states`` stacks the per-layer
          states along dim 0, or is a pair of empty tensors when ``states``
          is None.
        """
        if states is not None:
            assert not self.training
            assert len(states) == 2
            if not torch.jit.is_tracing():
                # for hidden state
                assert states[0].shape == (
                    self.num_layers,
                    src.size(1),
                    self.d_model,
                )
                # for cell state
                assert states[1].shape == (
                    self.num_layers,
                    src.size(1),
                    self.rnn_hidden_size,
                )

        output = src

        # Outputs of the layers listed in self.aux_layers, for the combiner.
        aux_outputs = []

        layer_hidden = []
        layer_cell = []

        for i, mod in enumerate(self.layers):
            if states is not None:
                # Slice (rather than index) to keep the leading dim of size 1.
                layer_state = (
                    states[0][i : i + 1, :, :],  # h: (1, N, d_model)
                    states[1][i : i + 1, :, :],  # c: (1, N, rnn_hidden_size)
                )
                output, (h, c) = mod(output, layer_state)
                layer_hidden.append(h)
                layer_cell.append(c)
            else:
                output = mod(output, warmup=warmup)[0]

            if self.combiner is not None and i in self.aux_layers:
                aux_outputs.append(output)

        if self.combiner is not None:
            output = self.combiner(aux_outputs)

        if states is not None:
            new_states = (
                torch.cat(layer_hidden, dim=0),
                torch.cat(layer_cell, dim=0),
            )
        else:
            new_states = (torch.empty(0), torch.empty(0))

        return output, new_states
|
||||
|
||||
|
||||
class Conv2dSubsampling(nn.Module):
    """2-D convolutional front end that reduces the frame rate to about 1/4.

    An input of shape (N, T, idim) is turned into an output of shape
    (N, T', odim), where
        T' = ((T-3)//2-1)//2, which approximates T' == T//4

    It is based on
    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py  # noqa
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        layer1_channels: int = 8,
        layer2_channels: int = 32,
        layer3_channels: int = 128,
    ) -> None:
        """
        Args:
          in_channels:
            Feature dimension of the input, whose shape is
            (N, T, in_channels).
            Caution: It requires: T >= 9, in_channels >= 9.
          out_channels:
            Output dim. The output shape is
            (N, ((T-3)//2-1)//2, out_channels).
          layer1_channels:
            Number of output channels of the first convolution.
          layer2_channels:
            Number of output channels of the second convolution.
          layer3_channels:
            Number of output channels of the third convolution.
        """
        assert in_channels >= 9
        super().__init__()

        # (input channels, output channels, extra conv arguments) for the
        # three conv stages.  Only the 2nd and 3rd stage use stride 2; every
        # stage is followed by a balancer and a DoubleSwish activation.
        stage_specs = (
            (1, layer1_channels, {"padding": 0}),
            (layer1_channels, layer2_channels, {"stride": 2}),
            (layer2_channels, layer3_channels, {"stride": 2}),
        )
        stages = []
        for stage_in, stage_out, extra in stage_specs:
            stages.append(
                ScaledConv2d(
                    in_channels=stage_in,
                    out_channels=stage_out,
                    kernel_size=3,
                    **extra,
                )
            )
            stages.append(ActivationBalancer(channel_dim=1))
            stages.append(DoubleSwish())
        self.conv = nn.Sequential(*stages)

        # Number of positions left on the feature axis after the conv stack.
        reduced_feature_dim = ((in_channels - 3) // 2 - 1) // 2
        self.out = ScaledLinear(
            layer3_channels * reduced_feature_dim, out_channels
        )
        # set learn_eps=False because out_norm is preceded by `out`, and `out`
        # itself has learned scale, so the extra degree of freedom is not
        # needed.
        self.out_norm = BasicNorm(out_channels, learn_eps=False)
        # constrain median of output to be close to zero.
        self.out_balancer = ActivationBalancer(
            channel_dim=-1, min_positive=0.45, max_positive=0.55
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Subsample x.

        Args:
          x:
            Its shape is (N, T, idim).

        Returns:
          Return a tensor of shape (N, ((T-3)//2-1)//2, odim)
        """
        # (N, T, idim) -> (N, 1, T, idim), i.e. (N, C, H, W) for the convs.
        conv_out = self.conv(x.unsqueeze(1))
        # conv_out: (N, layer3_channels, ((T-3)//2-1)//2, ((idim-3)//2-1)//2)
        batch, channels, frames, bins = conv_out.size()
        # Move time next to batch and merge channel and feature axes so that
        # each frame becomes a single vector for the linear projection.
        per_frame = (
            conv_out.transpose(1, 2)
            .contiguous()
            .view(batch, frames, channels * bins)
        )
        # projected: (N, ((T-3)//2-1)//2, odim)
        projected = self.out(per_frame)
        return self.out_balancer(self.out_norm(projected))
|
||||
|
||||
|
||||
class RandomCombine(nn.Module):
    """
    This module combines a list of Tensors, all with the same shape, to
    produce a single output of that same shape which, in training time,
    is a random combination of all the inputs; but which in test time
    will be just the last input.

    The idea is that the list of Tensors will be a list of outputs of multiple
    encoder layers.  This has a similar effect as iterated loss. (See:
    DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER
    NETWORKS).
    """

    def __init__(
        self,
        num_inputs: int,
        final_weight: float = 0.5,
        pure_prob: float = 0.5,
        stddev: float = 2.0,
    ) -> None:
        """
        Args:
          num_inputs:
            The number of tensor inputs, which equals the number of layers'
            outputs that are fed into this module.  E.g. in an 18-layer neural
            net if we output layers 16, 12, 18, num_inputs would be 3.
          final_weight:
            The amount of weight or probability we assign to the
            final layer when randomly choosing layers or when choosing
            continuous layer weights.  Must satisfy 0 < final_weight < 1.
          pure_prob:
            The probability, on each frame, with which we choose
            only a single layer to output (rather than an interpolation).
            Must satisfy 0 <= pure_prob <= 1.
          stddev:
            The standard deviation of the Gaussian noise that we add to
            log-probs for computing randomized weights.

        The method of choosing which layers, or combinations of layers, to use,
        is conceptually as follows::

            With probability `pure_prob`::
               With probability `final_weight`: choose final layer,
               Else: choose random non-final layer.
            Else::
               Choose initial log-weights that correspond to assigning
               weight `final_weight` to the final layer and equal
               weights to other layers; then add Gaussian noise
               with standard deviation `stddev` to these log-weights, and
               normalize to weights (note: the average weight assigned to the
               final layer here will not be `final_weight` if stddev>0).

        With `num_inputs == 1` there is nothing to combine and `forward`
        simply returns its only input.
        """
        super().__init__()
        assert 0 <= pure_prob <= 1, pure_prob
        assert 0 < final_weight < 1, final_weight
        assert num_inputs >= 1

        self.num_inputs = num_inputs
        self.final_weight = final_weight
        self.pure_prob = pure_prob
        self.stddev = stddev

        # Log-weight offset added to the final layer's logit in the "mixed"
        # branch.  For num_inputs == 1 this evaluates to -inf, but it is
        # never used then because forward() returns early in that case.
        self.final_log_weight = (
            torch.tensor(
                (final_weight / (1 - final_weight)) * (self.num_inputs - 1)
            )
            .log()
            .item()
        )

    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
        """Forward function.
        Args:
          inputs:
            A list of Tensor, e.g. from various layers of a transformer.
            All must be the same shape, of (*, num_channels)
        Returns:
          A Tensor of shape (*, num_channels).  In test mode
          this is just the final input.
        """
        num_inputs = self.num_inputs
        assert len(inputs) == num_inputs
        if not self.training or torch.jit.is_scripting():
            return inputs[-1]

        if num_inputs == 1:
            # A single input cannot be mixed with anything.  Without this
            # guard the random-weight helpers would call torch.randint with
            # an empty range or take a softmax over a single -inf.
            return inputs[0]

        # Shape of weights: (*, num_inputs)
        num_channels = inputs[0].shape[-1]
        num_frames = inputs[0].numel() // num_channels

        ndim = inputs[0].ndim
        # stacked_inputs: (num_frames, num_channels, num_inputs)
        stacked_inputs = torch.stack(inputs, dim=ndim).reshape(
            (num_frames, num_channels, num_inputs)
        )

        # weights: (num_frames, num_inputs)
        weights = self._get_random_weights(
            inputs[0].dtype, inputs[0].device, num_frames
        )

        weights = weights.reshape(num_frames, num_inputs, 1)
        # ans: (num_frames, num_channels, 1)
        ans = torch.matmul(stacked_inputs, weights)
        # ans: (*, num_channels)
        ans = ans.reshape(inputs[0].shape[:-1] + (num_channels,))

        return ans

    def _get_random_weights(
        self, dtype: torch.dtype, device: torch.device, num_frames: int
    ) -> torch.Tensor:
        """Return a tensor of random weights, of shape
        `(num_frames, self.num_inputs)`,
        Args:
          dtype:
            The data-type desired for the answer, e.g. float, double.
          device:
            The device needed for the answer.
          num_frames:
            The number of sets of weights desired
        Returns:
          A tensor of shape (num_frames, self.num_inputs), such that
          `ans.sum(dim=1)` is all ones.
        """
        pure_prob = self.pure_prob
        if pure_prob == 0.0:
            return self._get_random_mixed_weights(dtype, device, num_frames)
        elif pure_prob == 1.0:
            return self._get_random_pure_weights(dtype, device, num_frames)
        else:
            p = self._get_random_pure_weights(dtype, device, num_frames)
            m = self._get_random_mixed_weights(dtype, device, num_frames)
            # Per frame, pick the one-hot row with probability pure_prob and
            # the interpolating row otherwise.
            return torch.where(
                torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m
            )

    def _get_random_pure_weights(
        self, dtype: torch.dtype, device: torch.device, num_frames: int
    ):
        """Return a tensor of random one-hot weights, of shape
        `(num_frames, self.num_inputs)`,
        Args:
          dtype:
            The data-type desired for the answer, e.g. float, double.
          device:
            The device needed for the answer.
          num_frames:
            The number of sets of weights desired.
        Returns:
          A one-hot tensor of shape `(num_frames, self.num_inputs)`, with
          exactly one weight equal to 1.0 on each frame.
        """
        final_prob = self.final_weight

        # final contains self.num_inputs - 1 in all elements.  The dtype is
        # given explicitly so that the indexes are integral (as one_hot
        # requires) irrespective of how torch.full infers the dtype.
        final = torch.full(
            (num_frames,),
            self.num_inputs - 1,
            dtype=torch.int64,
            device=device,
        )
        # nonfinal contains random integers in [0..num_inputs - 2], these are
        # for non-final weights.
        nonfinal = torch.randint(
            self.num_inputs - 1, (num_frames,), device=device
        )

        indexes = torch.where(
            torch.rand(num_frames, device=device) < final_prob, final, nonfinal
        )
        ans = torch.nn.functional.one_hot(
            indexes, num_classes=self.num_inputs
        ).to(dtype=dtype)
        return ans

    def _get_random_mixed_weights(
        self, dtype: torch.dtype, device: torch.device, num_frames: int
    ):
        """Return a tensor of random interpolation weights, of shape
        `(num_frames, self.num_inputs)`,
        Args:
          dtype:
            The data-type desired for the answer, e.g. float, double.
          device:
            The device needed for the answer.
          num_frames:
            The number of sets of weights desired.
        Returns:
          A tensor of shape (num_frames, self.num_inputs), with elements
          in [0..1] that sum to one over the second axis, i.e.
          `ans.sum(dim=1)` is all ones.
        """
        # Gaussian noise on the log-weights, one row per frame.
        logprobs = (
            torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device)
            * self.stddev
        )
        # Bias the final layer so that, without noise, it would get weight
        # `final_weight` after the softmax.
        logprobs[:, -1] += self.final_log_weight
        return logprobs.softmax(dim=1)
|
||||
|
||||
|
||||
def _test_random_combine(final_weight: float, pure_prob: float, stddev: float):
    """Smoke-test RandomCombine for one hyper-parameter setting.

    Three all-ones tensors are combined; because the weights on each frame
    sum to one, the result must again be all ones and keep the input shape.

    Args:
      final_weight:
        Passed through to RandomCombine.
      pure_prob:
        Passed through to RandomCombine.
      stddev:
        Passed through to RandomCombine.
    """
    print(
        f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}"  # noqa
    )
    input_count = 3
    channel_count = 50
    combiner = RandomCombine(
        num_inputs=input_count,
        final_weight=final_weight,
        pure_prob=pure_prob,
        stddev=stddev,
    )

    tensors = [torch.ones(3, 4, channel_count) for _ in range(input_count)]
    combined = combiner(tensors)

    assert combined.shape == tensors[0].shape
    # All inputs are identical, so any convex combination equals each of them.
    assert torch.allclose(combined, tensors[0])
|
||||
|
||||
|
||||
def _test_random_combine_main():
    """Run the RandomCombine smoke tests, then one forward pass of RNN."""
    # (final_weight, pure_prob, stddev) for each run, in order.
    settings = (
        (0.999, 0, 0.0),
        (0.5, 0, 0.0),
        (0.999, 0, 0.0),
        (0.5, 0, 0.3),
        (0.5, 1, 0.3),
        (0.5, 0.5, 0.3),
    )
    for final_weight, pure_prob, stddev in settings:
        _test_random_combine(final_weight, pure_prob, stddev)

    feature_dim = 50
    encoder = RNN(num_features=feature_dim, d_model=128)
    batch_size = 5
    seq_len = 20
    features = torch.randn(batch_size, seq_len, feature_dim)
    feature_lens = torch.full((batch_size,), seq_len, dtype=torch.int64)
    # Just make sure the forward pass runs; the output is not inspected.
    encoder(features, feature_lens)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: build an encoder, push one random batch through it,
    # report its size, then run the RandomCombine tests.
    feature_dim = 80
    batch_size = 5
    seq_len = 20

    m = RNN(
        num_features=feature_dim,
        d_model=512,
        rnn_hidden_size=1024,
        dim_feedforward=2048,
        num_encoder_layers=12,
    )
    # Just make sure the forward pass runs.
    f = m(
        torch.randn(batch_size, seq_len, feature_dim),
        torch.full((batch_size,), seq_len, dtype=torch.int64),
        warmup=0.5,
    )
    num_param = sum(p.numel() for p in m.parameters())
    print(f"Number of model parameters: {num_param}")

    _test_random_combine_main()
|
202
egs/librispeech/ASR/lstm_transducer_stateless/model.py
Normal file
202
egs/librispeech/ASR/lstm_transducer_stateless/model.py
Normal file
@ -0,0 +1,202 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from encoder_interface import EncoderInterface
|
||||
from scaling import ScaledLinear
|
||||
|
||||
from icefall.utils import add_sos
|
||||
|
||||
|
||||
class Transducer(nn.Module):
    """Transducer model trained with k2's smoothed and pruned RNN-T losses.

    It implements https://arxiv.org/pdf/1211.3711.pdf
    "Sequence Transduction with Recurrent Neural Networks"
    """

    def __init__(
        self,
        encoder: EncoderInterface,
        decoder: nn.Module,
        joiner: nn.Module,
        encoder_dim: int,
        decoder_dim: int,
        joiner_dim: int,
        vocab_size: int,
    ):
        """
        Args:
          encoder:
            It is the transcription network in the paper. It is called as
            `encoder(x, x_lens, warmup=warmup)` with `x` of shape (N, T, C)
            and `x_lens` of shape (N,), and must return three values of
            which the first two are the encoder output of shape
            (N, T', encoder_dim) and its lengths of shape (N,). The third
            value is ignored here.
          decoder:
            It is the prediction network in the paper. Its input shape
            is (N, U) and its output shape is (N, U, decoder_dim).
            It should contain one attribute: `blank_id`.
          joiner:
            It must provide the sub-modules `encoder_proj` and
            `decoder_proj` and accept `project_input=False` when called on
            already-projected inputs. Its output contains unnormalized
            probs, i.e., not processed by log-softmax.
          encoder_dim:
            Input dimension of the simple (linear) acoustic projection.
          decoder_dim:
            Input dimension of the simple (linear) LM projection.
          joiner_dim:
            Accepted for interface compatibility; not used by this class.
          vocab_size:
            Output dimension of the two simple projections.
        """
        super().__init__()
        assert isinstance(encoder, EncoderInterface), type(encoder)
        assert hasattr(decoder, "blank_id")

        self.encoder = encoder
        self.decoder = decoder
        self.joiner = joiner

        # Linear heads used only for the "simple" loss that produces the
        # pruning bounds.
        self.simple_am_proj = ScaledLinear(
            encoder_dim, vocab_size, initial_speed=0.5
        )
        self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: k2.RaggedTensor,
        prune_range: int = 5,
        am_scale: float = 0.0,
        lm_scale: float = 0.0,
        warmup: float = 1.0,
        reduction: str = "sum",
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
          x:
            A 3-D tensor of shape (N, T, C).
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of frames in `x`
            before padding.
          y:
            A ragged tensor with 2 axes [utt][label]. It contains labels of each
            utterance.
          prune_range:
            The prune range for rnnt loss, it means how many symbols(context)
            we are considering for each frame to compute the loss.
          am_scale:
            The scale to smooth the loss with am (output of encoder network)
            part
          lm_scale:
            The scale to smooth the loss with lm (output of predictor network)
            part
          warmup:
            A value warmup >= 0 that determines which modules are active, values
            warmup > 1 "are fully warmed up" and all modules will be active.
            It is forwarded to the encoder.
          reduction:
            "sum" to sum the losses over all utterances in the batch.
            "none" to return the loss in a 1-D tensor for each utterance
            in the batch.
        Returns:
          Return a tuple `(simple_loss, pruned_loss)`.

        Note:
           Regarding am_scale & lm_scale, it will make the loss-function one of
           the form:
              lm_scale * lm_probs + am_scale * am_probs +
              (1-lm_scale-am_scale) * combined_probs
        """
        assert reduction in ("sum", "none"), reduction
        assert x.ndim == 3, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert y.num_axes == 2, y.num_axes

        assert x.size(0) == x_lens.size(0) == y.dim0

        # From here on `x_lens` holds the lengths of the encoder output.
        encoder_out, x_lens, _ = self.encoder(x, x_lens, warmup=warmup)
        assert torch.all(x_lens > 0)

        # Number of labels of each utterance, from the ragged row splits.
        label_splits = y.shape.row_splits(1)
        y_lens = label_splits[1:] - label_splits[:-1]

        # Now for the decoder, i.e., the prediction network.
        blank_id = self.decoder.blank_id

        # sos_y_padded: [B, S + 1], start with SOS.
        sos_y_padded = add_sos(y, sos_id=blank_id).pad(
            mode="constant", padding_value=blank_id
        )

        # decoder_out: [B, S + 1, decoder_dim]
        decoder_out = self.decoder(sos_y_padded)

        # Note: y does not start with SOS
        # y_padded : [B, S]
        y_padded = y.pad(mode="constant", padding_value=0).to(torch.int64)

        # Each row is [0, 0, number of labels, number of encoder frames].
        boundary = torch.zeros(
            (x.size(0), 4), dtype=torch.int64, device=x.device
        )
        boundary[:, 2] = y_lens
        boundary[:, 3] = x_lens

        lm = self.simple_lm_proj(decoder_out)
        am = self.simple_am_proj(encoder_out)

        # The losses are computed with autocast disabled, on float32 inputs.
        with torch.cuda.amp.autocast(enabled=False):
            simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                lm=lm.float(),
                am=am.float(),
                symbols=y_padded,
                termination_symbol=blank_id,
                lm_only_scale=lm_scale,
                am_only_scale=am_scale,
                boundary=boundary,
                reduction=reduction,
                return_grad=True,
            )

        # ranges : [B, T, prune_range]
        ranges = k2.get_rnnt_prune_ranges(
            px_grad=px_grad,
            py_grad=py_grad,
            boundary=boundary,
            s_range=prune_range,
        )

        # The joiner's input projections are applied before pruning (this is
        # an optimization for speed), hence project_input=False below.
        # am_pruned, lm_pruned : [B, T, prune_range, C'], where C' is the
        # output dim of the joiner's input projections.
        projected_am = self.joiner.encoder_proj(encoder_out)
        projected_lm = self.joiner.decoder_proj(decoder_out)
        am_pruned, lm_pruned = k2.do_rnnt_pruning(
            am=projected_am,
            lm=projected_lm,
            ranges=ranges,
        )

        # logits : [B, T, prune_range, vocab_size]
        logits = self.joiner(am_pruned, lm_pruned, project_input=False)

        with torch.cuda.amp.autocast(enabled=False):
            pruned_loss = k2.rnnt_loss_pruned(
                logits=logits.float(),
                symbols=y_padded,
                ranges=ranges,
                termination_symbol=blank_id,
                boundary=boundary,
                reduction=reduction,
            )

        return (simple_loss, pruned_loss)
|
1
egs/librispeech/ASR/lstm_transducer_stateless/optim.py
Symbolic link
1
egs/librispeech/ASR/lstm_transducer_stateless/optim.py
Symbolic link
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless2/optim.py
|
352
egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py
Executable file
352
egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py
Executable file
@ -0,0 +1,352 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
|
||||
(1) greedy search
|
||||
./lstm_transducer_stateless/pretrained.py \
|
||||
--checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--method greedy_search \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(2) beam search
|
||||
./lstm_transducer_stateless/pretrained.py \
|
||||
--checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--method beam_search \
|
||||
--beam-size 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(3) modified beam search
|
||||
./lstm_transducer_stateless/pretrained.py \
|
||||
--checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--method modified_beam_search \
|
||||
--beam-size 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(4) fast beam search
|
||||
./lstm_transducer_stateless/pretrained.py \
|
||||
--checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--method fast_beam_search \
|
||||
--beam 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
You can also use `./lstm_transducer_stateless/exp/epoch-xx.pt`.
|
||||
|
||||
Note: ./lstm_transducer_stateless/exp/pretrained.pt is generated by
|
||||
./lstm_transducer_stateless/export.py
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import k2
|
||||
import kaldifeat
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
|
||||
def get_parser():
    """Build the command-line parser of this script.

    Besides the options defined here, the model options registered by
    `add_model_arguments` are added as well.

    Returns:
      Return the fully populated `argparse.ArgumentParser`.
    """
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # ---- model / tokenizer / decoding method ----
    cli.add_argument(
        "--checkpoint",
        required=True,
        type=str,
        help="Path to the checkpoint. "
        "The checkpoint is assumed to be saved by "
        "icefall.checkpoint.save_checkpoint().",
    )

    cli.add_argument(
        "--bpe-model",
        type=str,
        help="""Path to bpe.model.""",
    )

    cli.add_argument(
        "--method",
        default="greedy_search",
        type=str,
        help="""Possible values are:
          - greedy_search
          - beam_search
          - modified_beam_search
          - fast_beam_search
        """,
    )

    # ---- audio input ----
    cli.add_argument(
        "sound_files",
        nargs="+",
        type=str,
        help="The input sound file(s) to transcribe. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported. "
        "The sample rate has to be 16kHz.",
    )

    cli.add_argument(
        "--sample-rate",
        default=16000,
        type=int,
        help="The sample rate of the input sound file",
    )

    # ---- search hyper-parameters ----
    cli.add_argument(
        "--beam-size",
        default=4,
        type=int,
        help="""An integer indicating how many candidates we will keep for each
        frame. Used only when --method is beam_search or
        modified_beam_search.""",
    )

    cli.add_argument(
        "--beam",
        default=4,
        type=float,
        help="""A floating point value to calculate the cutoff score during beam
        search (i.e., `cutoff = max-score - beam`), which is the same as the
        `beam` in Kaldi.
        Used only when --method is fast_beam_search""",
    )

    cli.add_argument(
        "--max-contexts",
        default=4,
        type=int,
        help="""Used only when --method is fast_beam_search""",
    )

    cli.add_argument(
        "--max-states",
        default=8,
        type=int,
        help="""Used only when --method is fast_beam_search""",
    )

    cli.add_argument(
        "--context-size",
        default=2,
        type=int,
        help="The context size in the decoder. 1 means bigram; "
        "2 means tri-gram",
    )
    cli.add_argument(
        "--max-sym-per-frame",
        default=1,
        type=int,
        help="""Maximum number of symbols per frame. Used only when
        --method is greedy_search.
        """,
    )

    add_model_arguments(cli)

    return cli
|
||||
|
||||
|
||||
def read_sound_files(
    filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
    """Load sound files and keep the first channel of each.

    Args:
      filenames:
        A list of sound filenames.
      expected_sample_rate:
        The expected sample rate of the sound files. A file with any other
        sample rate triggers an assertion failure.
    Returns:
      Return a list with one 1-D tensor per file, in the order of
      `filenames`.
    """

    def load_first_channel(path: str) -> torch.Tensor:
        # torchaudio.load returns (waveform, sample_rate); indexing the
        # waveform with 0 selects its first channel.
        wave, sample_rate = torchaudio.load(path)
        assert sample_rate == expected_sample_rate, (
            f"expected sample rate: {expected_sample_rate}. "
            f"Given: {sample_rate}"
        )
        # We use only the first channel
        return wave[0]

    return [load_first_channel(path) for path in filenames]
|
||||
|
||||
|
||||
@torch.no_grad()
def main():
    """Transcribe the sound files given on the command line.

    Steps: parse options, load the BPE model and the checkpoint, compute
    fbank features, run the encoder once on the padded batch, decode with
    the method selected by --method, and log one transcript per file.

    Raises:
      ValueError: If --method names an unsupported decoding method.
    """
    args = get_parser().parse_args()

    params = get_params()
    params.update(vars(args))

    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)

    # <blk> is defined in local/train_bpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
    params.unk_id = sp.piece_to_id("<unk>")
    params.vocab_size = sp.get_piece_size()

    logging.info(f"{params}")

    # Use the first GPU when one is available, otherwise the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    else:
        device = torch.device("cpu")

    logging.info(f"device: {device}")

    logging.info("Creating model")
    model = get_transducer_model(params)

    num_param = sum(p.numel() for p in model.parameters())
    logging.info(f"Number of model parameters: {num_param}")

    # Load the weights on the CPU first, then move the model to `device`.
    checkpoint = torch.load(args.checkpoint, map_location="cpu")
    model.load_state_dict(checkpoint["model"], strict=False)
    model.to(device)
    model.eval()
    model.device = device

    logging.info("Constructing Fbank computer")
    opts = kaldifeat.FbankOptions()
    opts.device = device
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = params.sample_rate
    opts.mel_opts.num_bins = params.feature_dim

    fbank = kaldifeat.Fbank(opts)

    logging.info(f"Reading sound files: {params.sound_files}")
    waves = [
        w.to(device)
        for w in read_sound_files(
            filenames=params.sound_files,
            expected_sample_rate=params.sample_rate,
        )
    ]

    logging.info("Decoding started")
    features = fbank(waves)
    # Remember the true lengths before the batch is padded.
    feature_lengths = torch.tensor(
        [f.size(0) for f in features], device=device
    )
    features = pad_sequence(
        features, batch_first=True, padding_value=math.log(1e-10)
    )

    encoder_out, encoder_out_lens, _ = model.encoder(
        x=features, x_lens=feature_lengths
    )

    num_waves = encoder_out.size(0)
    hyps = []
    msg = f"Using {params.method}"
    if params.method == "beam_search":
        msg += f" with beam size {params.beam_size}"
    logging.info(msg)

    method = params.method
    if method == "fast_beam_search":
        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
        hyp_tokens = fast_beam_search_one_best(
            model=model,
            decoding_graph=decoding_graph,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
        )
        hyps.extend(text.split() for text in sp.decode(hyp_tokens))
    elif method == "modified_beam_search":
        hyp_tokens = modified_beam_search(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam_size,
        )
        hyps.extend(text.split() for text in sp.decode(hyp_tokens))
    elif method == "greedy_search" and params.max_sym_per_frame == 1:
        hyp_tokens = greedy_search_batch(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
        )
        hyps.extend(text.split() for text in sp.decode(hyp_tokens))
    else:
        # The remaining methods decode one utterance at a time, each on its
        # un-padded slice of the encoder output.
        for i in range(num_waves):
            # fmt: off
            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
            # fmt: on
            if method == "greedy_search":
                hyp = greedy_search(
                    model=model,
                    encoder_out=encoder_out_i,
                    max_sym_per_frame=params.max_sym_per_frame,
                )
            elif method == "beam_search":
                hyp = beam_search(
                    model=model,
                    encoder_out=encoder_out_i,
                    beam=params.beam_size,
                )
            else:
                raise ValueError(f"Unsupported method: {params.method}")

            hyps.append(sp.decode(hyp).split())

    # One "<file>:\n<transcript>" paragraph per input, in a single record.
    transcripts = (" ".join(hyp) for hyp in hyps)
    s = "\n" + "".join(
        f"{filename}:\n{words}\n\n"
        for filename, words in zip(params.sound_files, transcripts)
    )
    logging.info(s)

    logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Log records carry a timestamp, the level and the emitting file:line.
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )
    logging.basicConfig(level=logging.INFO, format=formatter)

    main()
|
1
egs/librispeech/ASR/lstm_transducer_stateless/scaling.py
Symbolic link
1
egs/librispeech/ASR/lstm_transducer_stateless/scaling.py
Symbolic link
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless2/scaling.py
|
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless3/scaling_converter.py
|
148
egs/librispeech/ASR/lstm_transducer_stateless/stream.py
Normal file
148
egs/librispeech/ASR/lstm_transducer_stateless/stream.py
Normal file
@ -0,0 +1,148 @@
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||
# Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import torch
|
||||
from beam_search import Hypothesis, HypothesisList
|
||||
|
||||
from icefall.utils import AttributeDict
|
||||
|
||||
|
||||
class Stream(object):
    """Per-utterance bookkeeping for chunk-by-chunk decoding.

    A stream holds the padded features of one cut, a cursor into those
    features, the encoder states cached between chunks, and the partial
    hypothesis of the selected decoding method.
    """

    def __init__(
        self,
        params: AttributeDict,
        cut_id: str,
        decoding_graph: Optional[k2.Fsa] = None,
        device: torch.device = torch.device("cpu"),
        LOG_EPS: float = math.log(1e-10),
    ) -> None:
        """
        Args:
          params:
            Configuration object. This constructor reads `context_size`,
            `decoding_method`, `blank_id` and `subsampling_factor` from it.
          cut_id:
            Identifier of the cut handled by this stream.
          decoding_graph:
            Passed to `k2.RnntDecodingStream`; only used when
            `params.decoding_method` is fast_beam_search.
          device:
            Device on which the initial log-prob of modified_beam_search
            is allocated.
          LOG_EPS:
            Value used to pad the features in :meth:`set_feature`.

        Raises:
          ValueError: if `params.decoding_method` is not one of
            greedy_search, modified_beam_search or fast_beam_search.
        """
        self.LOG_EPS = LOG_EPS
        self.cut_id = cut_id

        # Encoder states carried over from one chunk to the next. They
        # are not created here; the owner of the stream assigns them.
        self.states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None

        self.context_size = params.context_size
        self.decoding_method = params.decoding_method

        # Each decoding method keeps its partial result in its own
        # attribute(s).
        method = params.decoding_method
        if method == "greedy_search":
            # Start from `context_size` blanks so that the decoder always
            # has a full context window.
            self.hyp = [params.blank_id] * params.context_size
        elif method == "modified_beam_search":
            initial_hyp = Hypothesis(
                ys=[params.blank_id] * params.context_size,
                log_prob=torch.zeros(1, dtype=torch.float32, device=device),
            )
            self.hyps = HypothesisList()
            self.hyps.add(initial_hyp)
        elif method == "fast_beam_search":
            # Search state kept by k2 for this stream.
            self.rnnt_decoding_stream: k2.RnntDecodingStream = (
                k2.RnntDecodingStream(decoding_graph)
            )
            # Filled in by the search routine after each chunk.
            self.hyp: Optional[List[int]] = None
        else:
            raise ValueError(
                f"Unsupported decoding method: {params.decoding_method}"
            )

        self.ground_truth: str = ""

        self.feature: Optional[torch.Tensor] = None
        # Each call of get_feature_chunk() advances by `chunk_length`
        # frames and additionally returns `pad_length` look-ahead frames.
        self.chunk_length = params.subsampling_factor
        self.pad_length = 5
        self.num_frames = 0
        self.num_processed_frames = 0

        # Becomes True once every frame has been handed out.
        self._done = False

    def set_feature(self, feature: torch.Tensor) -> None:
        """Attach the features of the whole utterance to this stream.

        Args:
          feature:
            A 2-D tensor; frames are along dim 0.
        """
        assert feature.dim() == 2, feature.dim()
        # Extra frames appended at the end to alleviate the tail
        # deletion problem.
        num_tail_padded_frames = 35
        self.num_frames = feature.size(0) + num_tail_padded_frames
        # Pad only at the end of the frame axis: the tail frames counted
        # in `num_frames` plus the look-ahead needed by the last chunk.
        frames_to_append = self.pad_length + num_tail_padded_frames
        self.feature = torch.nn.functional.pad(
            feature,
            (0, 0, 0, frames_to_append),
            mode="constant",
            value=self.LOG_EPS,
        )

    def get_feature_chunk(self) -> torch.Tensor:
        """Return the next chunk of frames and advance the cursor.

        Returns:
          A tensor of shape (ret_length, feature_dim), where ret_length
          is the number of newly consumed frames plus `pad_length`.
        """
        start = self.num_processed_frames
        remaining = self.num_frames - start
        update_length = min(remaining, self.chunk_length)
        stop = start + update_length + self.pad_length

        ret_feature = self.feature[start:stop]

        self.num_processed_frames = start + update_length
        if self.num_processed_frames >= self.num_frames:
            self._done = True

        return ret_feature

    @property
    def id(self) -> str:
        """The cut id given at construction time."""
        return self.cut_id

    @property
    def done(self) -> bool:
        """True once all feature frames have been handed out."""
        return self._done

    def decoding_result(self) -> List[int]:
        """Return the tokens decoded so far, without the blank context."""
        method = self.decoding_method
        if method == "greedy_search":
            return self.hyp[self.context_size :]  # noqa
        if method == "modified_beam_search":
            best_hyp = self.hyps.get_most_probable(length_norm=True)
            return best_hyp.ys[self.context_size :]  # noqa
        assert method == "fast_beam_search"
        return self.hyp
|
968
egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py
Executable file
968
egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py
Executable file
@ -0,0 +1,968 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||
# Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
(1) greedy search
|
||||
./lstm_transducer_stateless/streaming_decode.py \
|
||||
--epoch 35 \
|
||||
--avg 10 \
|
||||
--exp-dir lstm_transducer_stateless/exp \
|
||||
--num-decode-streams 2000 \
|
||||
--num-encoder-layers 12 \
|
||||
--rnn-hidden-size 1024 \
|
||||
--decoding-method greedy_search \
|
||||
--use-averaged-model True
|
||||
|
||||
(2) modified beam search
|
||||
./lstm_transducer_stateless/streaming_decode.py \
|
||||
--epoch 35 \
|
||||
--avg 10 \
|
||||
--exp-dir lstm_transducer_stateless/exp \
|
||||
--num-decode-streams 2000 \
|
||||
--num-encoder-layers 12 \
|
||||
--rnn-hidden-size 1024 \
|
||||
--decoding-method modified_beam_search \
|
||||
--use-averaged-model True \
|
||||
--beam-size 4
|
||||
|
||||
(3) fast beam search
|
||||
./lstm_transducer_stateless/streaming_decode.py \
|
||||
--epoch 35 \
|
||||
--avg 10 \
|
||||
--exp-dir lstm_transducer_stateless/exp \
|
||||
--num-decode-streams 2000 \
|
||||
--num-encoder-layers 12 \
|
||||
--rnn-hidden-size 1024 \
|
||||
--decoding-method fast_beam_search \
|
||||
--use-averaged-model True \
|
||||
--beam 4 \
|
||||
--max-contexts 4 \
|
||||
--max-states 8
|
||||
"""
|
||||
import argparse
|
||||
import logging
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import numpy as np
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from beam_search import Hypothesis, HypothesisList, get_hyps_shape
|
||||
from kaldifeat import Fbank, FbankOptions
|
||||
from lhotse import CutSet
|
||||
from lstm import LOG_EPSILON, stack_states, unstack_states
|
||||
from stream import Stream
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.decode import one_best_decoding
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
get_texts,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
|
||||
def get_parser():
    """Build the command-line parser of the streaming decoding script.

    It registers the checkpoint-selection, decoding-method and stream
    batching options, then lets `add_model_arguments` append the model
    options.

    Returns:
      The configured `argparse.ArgumentParser`.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # Adjacent string literals are concatenated verbatim, so every
    # multi-part help text below ends its pieces with an explicit space.
    parser.add_argument(
        "--epoch",
        type=int,
        default=28,
        help="It specifies the checkpoint to use for decoding. "
        "Note: Epoch counts from 0.",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch'. ",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=False,
        help="Whether to load averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging. ",
    )

    # NOTE(review): this default names a different recipe's directory;
    # the usage examples of this script always pass --exp-dir explicitly.
    # It is kept unchanged so that existing invocations behave the same.
    parser.add_argument(
        "--exp-dir",
        type=str,
        default="transducer_emformer/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        default="data/lang_bpe_500/bpe.model",
        help="Path to the BPE model",
    )

    parser.add_argument(
        "--decoding-method",
        type=str,
        default="greedy_search",
        help="""Possible values are:
          - greedy_search
          - modified_beam_search
          - fast_beam_search
        """,
    )

    parser.add_argument(
        "--beam-size",
        type=int,
        default=4,
        help="""An integer indicating how many candidates we will keep for each
        frame. Used only when --decoding-method is beam_search or
        modified_beam_search.""",
    )

    parser.add_argument(
        "--beam",
        type=float,
        default=20.0,
        help="""A floating point value to calculate the cutoff score during beam
        search (i.e., `cutoff = max-score - beam`), which is the same as the
        `beam` in Kaldi.
        Used only when --decoding-method is fast_beam_search""",
    )

    parser.add_argument(
        "--max-contexts",
        type=int,
        default=8,
        help="""Used only when --decoding-method is
        fast_beam_search""",
    )

    parser.add_argument(
        "--max-states",
        type=int,
        default=64,
        help="""Used only when --decoding-method is
        fast_beam_search""",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; "
        "2 means tri-gram",
    )

    parser.add_argument(
        "--max-sym-per-frame",
        type=int,
        default=1,
        help="""Maximum number of symbols per frame.
        Used only when --decoding_method is greedy_search""",
    )

    parser.add_argument(
        "--sampling-rate",
        type=float,
        default=16000,
        help="Sample rate of the audio",
    )

    parser.add_argument(
        "--num-decode-streams",
        type=int,
        default=2000,
        help="The number of streams that can be decoded in parallel",
    )

    # Model-architecture options are registered by the training script's
    # helper.
    add_model_arguments(parser)

    return parser
|
||||
|
||||
|
||||
def greedy_search(
    model: nn.Module,
    encoder_out: torch.Tensor,
    streams: List[Stream],
) -> None:
    """Batched greedy search emitting at most one symbol per frame.

    The `hyp` list of every stream is extended in place; nothing is
    returned.

    Args:
      model:
        The transducer model. Its `decoder` and `joiner` are used.
      encoder_out:
        Encoder output of shape (N, T, C), one row per stream.
      streams:
        The N streams being decoded, in the same order as `encoder_out`.
    """
    assert len(streams) == encoder_out.size(0)
    assert encoder_out.ndim == 3

    decoder = model.decoder
    blank_id = decoder.blank_id
    context_size = decoder.context_size
    device = next(model.parameters()).device
    num_frames = encoder_out.size(1)

    joiner = model.joiner

    def projected_decoder_out() -> torch.Tensor:
        # Run the decoder on the last `context_size` tokens of every
        # stream and project the result for the joiner.
        contexts = torch.tensor(
            [stream.hyp[-context_size:] for stream in streams],
            device=device,
            dtype=torch.int64,
        )
        return joiner.decoder_proj(decoder(contexts, need_pad=False))

    encoder_out = joiner.encoder_proj(encoder_out)
    decoder_out = projected_decoder_out()

    for t in range(num_frames):
        # Keep the time axis so the slice stays 3-D: (N, 1, dim)
        frame = encoder_out[:, t : t + 1, :]  # noqa

        logits = joiner(
            frame.unsqueeze(2),
            decoder_out.unsqueeze(1),
            project_input=False,
        )
        # Drop the two singleton axes; what is left must be 2-D.
        logits = logits.squeeze(1).squeeze(1)
        assert logits.ndim == 2, logits.shape

        best_tokens = logits.argmax(dim=1).tolist()

        emitted = False
        for stream, token in zip(streams, best_tokens):
            if token == blank_id:
                continue
            stream.hyp.append(token)
            emitted = True

        # The decoder output only changes when some hypothesis grew.
        if emitted:
            decoder_out = projected_decoder_out()
|
||||
|
||||
|
||||
def modified_beam_search(
    model: nn.Module,
    encoder_out: torch.Tensor,
    streams: List[Stream],
    beam: int = 4,
):
    """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.

    The `hyps` attribute of every stream is read at the start and replaced
    with the surviving hypotheses at the end; nothing is returned.

    Args:
      model:
        The RNN-T model.
      encoder_out:
        A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of
        the encoder model.
      streams:
        A list of stream objects.
      beam:
        Number of active paths during the beam search.
    """
    assert encoder_out.ndim == 3, encoder_out.shape
    assert len(streams) == encoder_out.size(0)

    blank_id = model.decoder.blank_id
    context_size = model.decoder.context_size
    device = next(model.parameters()).device
    batch_size = len(streams)
    T = encoder_out.size(1)

    # B[i] holds the active hypotheses of stream i.
    B = [stream.hyps for stream in streams]

    encoder_out = model.joiner.encoder_proj(encoder_out)

    for t in range(T):
        current_encoder_out = encoder_out[:, t].unsqueeze(1).unsqueeze(1)
        # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)

        # Ragged shape describing how many hypotheses each stream has.
        hyps_shape = get_hyps_shape(B).to(device)

        # A is a snapshot of the hypotheses to expand on this frame;
        # B is reset and collects the survivors.
        A = [list(b) for b in B]
        B = [HypothesisList() for _ in range(batch_size)]

        ys_log_probs = torch.stack(
            [hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0
        )  # (num_hyps, 1)

        decoder_input = torch.tensor(
            [hyp.ys[-context_size:] for hyps in A for hyp in hyps],
            device=device,
            dtype=torch.int64,
        )  # (num_hyps, context_size)

        decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
        decoder_out = model.joiner.decoder_proj(decoder_out)
        # decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim)

        # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
        # as index, so we use `to(torch.int64)` below.
        # Repeat each stream's frame once per hypothesis of that stream.
        current_encoder_out = torch.index_select(
            current_encoder_out,
            dim=0,
            index=hyps_shape.row_ids(1).to(torch.int64),
        )  # (num_hyps, 1, 1, encoder_out_dim)

        logits = model.joiner(
            current_encoder_out, decoder_out, project_input=False
        )
        # logits is of shape (num_hyps, 1, 1, vocab_size)

        logits = logits.squeeze(1).squeeze(1)

        log_probs = logits.log_softmax(dim=-1)  # (num_hyps, vocab_size)

        # Add the score of the prefix to every possible extension.
        log_probs.add_(ys_log_probs)

        vocab_size = log_probs.size(-1)

        # Flatten so that all extensions of one stream are contiguous;
        # scaling the row splits by vocab_size gives one ragged row per
        # stream.
        log_probs = log_probs.reshape(-1)

        row_splits = hyps_shape.row_splits(1) * vocab_size
        log_probs_shape = k2.ragged.create_ragged_shape2(
            row_splits=row_splits, cached_tot_size=log_probs.numel()
        )
        ragged_log_probs = k2.RaggedTensor(
            shape=log_probs_shape, value=log_probs
        )

        for i in range(batch_size):
            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)

            # A flat index encodes (hypothesis index, token id).
            # Warnings raised by the integer division are suppressed.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
                topk_token_indexes = (topk_indexes % vocab_size).tolist()

            for k in range(len(topk_hyp_indexes)):
                hyp_idx = topk_hyp_indexes[k]
                hyp = A[i][hyp_idx]

                # Copy, so the parent hypothesis is left untouched.
                new_ys = hyp.ys[:]
                new_token = topk_token_indexes[k]
                # A blank keeps the token sequence as it is.
                if new_token != blank_id:
                    new_ys.append(new_token)

                new_log_prob = topk_log_probs[k]
                new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
                B[i].add(new_hyp)

    # Hand the surviving hypotheses back to the streams.
    for i in range(batch_size):
        streams[i].hyps = B[i]
|
||||
|
||||
|
||||
def fast_beam_search_one_best(
    model: nn.Module,
    streams: List[Stream],
    encoder_out: torch.Tensor,
    processed_lens: torch.Tensor,
    beam: float,
    max_states: int,
    max_contexts: int,
) -> None:
    """Fast beam search with at most one symbol per frame.

    The k2 decoding streams of all `streams` are advanced over the frames
    of `encoder_out`, a lattice is built from them, and the best path of
    that lattice is stored in the `hyp` attribute of every stream.

    Args:
      model:
        An instance of `Transducer`.
      streams:
        A list of stream objects.
      encoder_out:
        A tensor of shape (N, T, C) from the encoder.
      processed_lens:
        A tensor of shape (N,) containing the number of processed frames
        in `encoder_out` before padding.
      beam:
        Beam value, similar to the beam used in Kaldi.
      max_states:
        Max states per stream per frame.
      max_contexts:
        Max contexts per stream per frame.
    """
    assert encoder_out.ndim == 3

    decoder = model.decoder
    context_size = decoder.context_size
    vocab_size = decoder.vocab_size

    batch_size, num_frames, _ = encoder_out.shape
    assert batch_size == len(streams)

    config = k2.RnntDecodingConfig(
        vocab_size=vocab_size,
        decoder_history_len=context_size,
        beam=beam,
        max_contexts=max_contexts,
        max_states=max_states,
    )
    # Combine the per-stream search states so that they advance together.
    decoding_streams = k2.RnntDecodingStreams(
        [stream.rnnt_decoding_stream for stream in streams], config
    )

    encoder_out = model.joiner.encoder_proj(encoder_out)

    for t in range(num_frames):
        # shape is a RaggedShape of shape (B, context)
        # contexts is a Tensor of shape (shape.NumElements(), context_size)
        shape, contexts = decoding_streams.get_contexts()
        # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
        contexts = contexts.to(torch.int64)

        # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
        decoder_out = model.joiner.decoder_proj(
            decoder(contexts, need_pad=False)
        )

        # Repeat the frame of each stream once per context of that
        # stream: (shape.NumElements(), 1, joiner_dim)
        frame = encoder_out[:, t : t + 1, :]  # noqa
        current_encoder_out = torch.index_select(
            frame, dim=0, index=shape.row_ids(1).to(torch.int64)
        )

        logits = model.joiner(
            current_encoder_out.unsqueeze(2),
            decoder_out.unsqueeze(1),
            project_input=False,
        )
        log_probs = logits.squeeze(1).squeeze(1).log_softmax(dim=-1)
        decoding_streams.advance(log_probs)

    # Write the updated search state back into the individual streams.
    decoding_streams.terminate_and_flush_to_streams()

    lattice = decoding_streams.format_output(processed_lens.tolist())
    hyps = get_texts(one_best_decoding(lattice))

    for i, stream in enumerate(streams):
        stream.hyp = hyps[i]
|
||||
|
||||
|
||||
def decode_one_chunk(
    model: nn.Module,
    streams: List[Stream],
    params: AttributeDict,
    decoding_graph: Optional[k2.Fsa] = None,
) -> List[int]:
    """Feed one feature chunk of every stream through the model.

    The chunks are batched, encoded with the cached encoder states, and
    decoded with the method selected by `params.decoding_method`. The
    hypotheses and cached states of the streams are updated in place.

    Args:
      model:
        The Transducer model.
      streams:
        A list of Stream objects.
      params:
        It is returned by :func:`get_params`.
      decoding_graph:
        The decoding graph. Not referenced inside this function.

    Returns:
      The indexes (into `streams`) of the streams that have no frames
      left.

    Raises:
      ValueError: if `params.decoding_method` is not supported.
    """
    device = next(model.parameters()).device

    chunks = []
    chunk_lens = []
    cached_states = []
    frames_done = []

    for stream in streams:
        # Read the cursor first: get_feature_chunk() advances it.
        frames_done.append(stream.num_processed_frames)
        chunk = stream.get_feature_chunk()
        chunks.append(chunk)
        chunk_lens.append(chunk.size(0))
        cached_states.append(stream.states)

    features = pad_sequence(
        chunks, batch_first=True, padding_value=LOG_EPSILON
    ).to(device)
    feature_lens = torch.tensor(chunk_lens, device=device)
    num_processed_frames = torch.tensor(frames_done, device=device)

    # Make sure it has at least 1 frame after subsampling
    tail_length = params.subsampling_factor + 5
    shortfall = tail_length - features.size(1)
    if shortfall > 0:
        feature_lens += shortfall
        features = torch.nn.functional.pad(
            features,
            (0, 0, 0, shortfall),
            mode="constant",
            value=LOG_EPSILON,
        )

    # Batch the per-stream states before calling the encoder.
    encoder_out, encoder_out_lens, states = model.encoder(
        x=features,
        x_lens=feature_lens,
        states=stack_states(cached_states),
    )

    method = params.decoding_method
    if method == "greedy_search":
        greedy_search(
            model=model,
            streams=streams,
            encoder_out=encoder_out,
        )
    elif method == "modified_beam_search":
        modified_beam_search(
            model=model,
            streams=streams,
            encoder_out=encoder_out,
            beam=params.beam_size,
        )
    elif method == "fast_beam_search":
        # Frames seen so far, counted after subsampling, including the
        # current chunk. Warnings from the integer division are muted.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            processed_lens = (
                num_processed_frames // params.subsampling_factor
                + encoder_out_lens
            )
        fast_beam_search_one_best(
            model=model,
            streams=streams,
            encoder_out=encoder_out,
            processed_lens=processed_lens,
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
        )
    else:
        raise ValueError(
            f"Unsupported decoding method: {params.decoding_method}"
        )

    # Give every stream its slice of the updated encoder states.
    for idx, new_state in enumerate(unstack_states(states)):
        streams[idx].states = new_state

    return [idx for idx, stream in enumerate(streams) if stream.done]
|
||||
|
||||
|
||||
def create_streaming_feature_extractor() -> Fbank:
    """Create the fbank extractor used to compute features on CPU.

    The options are fixed here: no dithering, `snip_edges` disabled,
    16 kHz input and 80 mel bins.

    Returns:
      An `Fbank` instance configured with those options.
    """
    fbank_opts = FbankOptions()
    fbank_opts.device = "cpu"

    # Mel filter bank
    fbank_opts.mel_opts.num_bins = 80

    # Framing
    fbank_opts.frame_opts.samp_freq = 16000
    fbank_opts.frame_opts.snip_edges = False
    fbank_opts.frame_opts.dither = 0

    return Fbank(fbank_opts)
|
||||
|
||||
|
||||
def decode_dataset(
    cuts: CutSet,
    model: nn.Module,
    params: AttributeDict,
    sp: spm.SentencePieceProcessor,
    decoding_graph: Optional[k2.Fsa] = None,
):
    """Decode a dataset by simulating parallel streams.

    One stream is created per cut. Whenever `params.num_decode_streams`
    streams are pending, they are decoded one chunk at a time until some
    of them finish; the remaining streams are drained at the end.

    Args:
      cuts:
        Lhotse Cutset containing the dataset to decode.
      model:
        The Transducer model.
      params:
        It is returned by :func:`get_params`.
      sp:
        The BPE model.
      decoding_graph:
        The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
        only when --decoding_method is fast_beam_search.

    Returns:
      A dict with a single entry. Its key is "greedy_search" for greedy
      search, "beam_{beam}_max_contexts_{n}_max_states_{m}" for fast beam
      search and "beam_size_{n}" otherwise. Its value is a list of tuples
      (cut id, reference words, hypothesis words).
    """
    device = next(model.parameters()).device

    log_interval = 300

    fbank = create_streaming_feature_extractor()

    decode_results = []
    streams = []

    def decode_chunk_and_collect() -> None:
        # Advance every pending stream by one chunk, then move the
        # finished ones from `streams` to `decode_results`.
        finished_streams = decode_one_chunk(
            model=model,
            streams=streams,
            params=params,
            decoding_graph=decoding_graph,
        )
        # Delete from the back so the remaining indexes stay valid.
        for i in sorted(finished_streams, reverse=True):
            finished = streams[i]
            decode_results.append(
                (
                    finished.id,
                    finished.ground_truth.split(),
                    sp.decode(finished.decoding_result()).split(),
                )
            )
            del streams[i]

    for num, cut in enumerate(cuts):
        # Each utterance has a Stream.
        stream = Stream(
            params=params,
            cut_id=cut.id,
            decoding_graph=decoding_graph,
            device=device,
            LOG_EPS=LOG_EPSILON,
        )
        stream.states = model.encoder.get_init_states(device=device)

        audio: np.ndarray = cut.load_audio()
        # audio.shape: (1, num_samples)
        assert len(audio.shape) == 2
        assert audio.shape[0] == 1, "Should be single channel"
        assert audio.dtype == np.float32, audio.dtype
        # The trained model is using normalized samples
        assert audio.max() <= 1, "Should be normalized to [-1, 1])"

        samples = torch.from_numpy(audio).squeeze(0)
        stream.set_feature(fbank(samples))
        stream.ground_truth = cut.supervisions[0].text

        streams.append(stream)

        # Decode until there is room for another stream.
        while len(streams) >= params.num_decode_streams:
            decode_chunk_and_collect()

        if num % log_interval == 0:
            logging.info(f"Cuts processed until now is {num}.")

    # Drain whatever is still pending.
    while len(streams) > 0:
        decode_chunk_and_collect()

    method = params.decoding_method
    if method == "greedy_search":
        key = "greedy_search"
    elif method == "fast_beam_search":
        key = (
            f"beam_{params.beam}_"
            f"max_contexts_{params.max_contexts}_"
            f"max_states_{params.max_states}"
        )
    else:
        key = f"beam_size_{params.beam_size}"

    return {key: decode_results}
|
||||
|
||||
|
||||
def save_results(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
):
    """Write transcripts, error statistics and a WER summary to disk.

    For every setting in `results_dict` a `recogs-*` and an `errs-*` file
    are written to `params.res_dir`. Afterwards one `wer-summary-*` file
    listing the settings sorted by WER is written and the same table is
    logged.

    Args:
      params:
        `res_dir` and `suffix` are read from it.
      test_set_name:
        Name of the test set; it becomes part of the file names.
      results_dict:
        Maps a setting name to its decoding results.
    """
    wer_of_setting = dict()
    for key, results in results_dict.items():
        recog_path = (
            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        store_transcripts(filename=recog_path, texts=sorted(results))
        logging.info(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = (
            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        with open(errs_filename, "w") as f:
            wer_of_setting[key] = write_error_stats(
                f, f"{test_set_name}-{key}", results, enable_log=True
            )

        logging.info("Wrote detailed error stats to {}".format(errs_filename))

    # Best (lowest) WER first.
    ranked = sorted(wer_of_setting.items(), key=lambda item: item[1])

    # NOTE: `key` is deliberately left over from the loop above, so the
    # summary file is named after the last setting in `results_dict`.
    errs_info = (
        params.res_dir
        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
    )
    with open(errs_info, "w") as f:
        print("settings\tWER", file=f)
        for setting, wer in ranked:
            print("{}\t{}".format(setting, wer), file=f)

    summary = "\nFor {}, WER of different settings are:\n".format(
        test_set_name
    )
    for rank, (setting, wer) in enumerate(ranked):
        # Only the first (best) line carries the marker.
        if rank == 0:
            note = "\tbest for {}".format(test_set_name)
        else:
            note = ""
        summary += "{}\t{}{}\n".format(setting, wer, note)
    logging.info(summary)
|
||||
|
||||
|
||||
@torch.no_grad()
def main():
    """Entry point of the streaming decoding script.

    Parses the command line, derives the result directory and file-name
    suffix from the decoding options, loads (and optionally averages)
    checkpoints into the transducer model, and then decodes the
    LibriSpeech test-clean and test-other sets, saving the results of each.
    """
    parser = get_parser()
    LibriSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    # Command-line arguments override / extend the default parameters.
    params = get_params()
    params.update(vars(args))

    assert params.decoding_method in (
        "greedy_search",
        "fast_beam_search",
        "modified_beam_search",
    )
    params.res_dir = params.exp_dir / "streaming" / params.decoding_method

    # The suffix encodes which checkpoints and decoding options were used;
    # it ends up in the names of all result files.
    if params.iter > 0:
        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
    else:
        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

    if "fast_beam_search" in params.decoding_method:
        params.suffix += f"-beam-{params.beam}"
        params.suffix += f"-max-contexts-{params.max_contexts}"
        params.suffix += f"-max-states-{params.max_states}"
    elif "beam_search" in params.decoding_method:
        params.suffix += (
            f"-{params.decoding_method}-beam-size-{params.beam_size}"
        )
    else:
        params.suffix += f"-context-{params.context_size}"
        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

    if params.use_averaged_model:
        params.suffix += "-use-averaged-model"

    setup_logger(f"{params.res_dir}/log-streaming-decode")
    logging.info("Decoding started")

    # Use the first CUDA device when one is available, otherwise the CPU.
    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"Device: {device}")

    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)

    # <blk> and <unk> are defined in local/train_bpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
    params.unk_id = sp.piece_to_id("<unk>")
    params.vocab_size = sp.get_piece_size()

    params.device = device

    logging.info(params)

    logging.info("About to create model")
    model = get_transducer_model(params)

    if not params.use_averaged_model:
        # Plain checkpoint averaging: average the parameters of the
        # selected checkpoints (or load a single one when --avg is 1).
        if params.iter > 0:
            filenames = find_checkpoints(
                params.exp_dir, iteration=-params.iter
            )[: params.avg]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
        elif params.avg == 1:
            # NOTE(review): unlike the other branches, this one does not call
            # model.to(device) -- confirm the model is moved to `device`
            # elsewhere before decoding.
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            start = params.epoch - params.avg + 1
            filenames = []
            for i in range(start, params.epoch + 1):
                # Epoch numbers below 1 are skipped.
                if i >= 1:
                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
    else:
        # Averaging via average_checkpoints_with_averaged_model, which only
        # needs the two end-point checkpoints of the range.
        if params.iter > 0:
            # One extra checkpoint is required because the start of the
            # range is excluded from the average.
            filenames = find_checkpoints(
                params.exp_dir, iteration=-params.iter
            )[: params.avg + 1]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )

    model.eval()

    # Only fast_beam_search needs a decoding graph.
    if params.decoding_method == "fast_beam_search":
        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
    else:
        decoding_graph = None

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    librispeech = LibriSpeechAsrDataModule(args)

    test_clean_cuts = librispeech.test_clean_cuts()
    test_other_cuts = librispeech.test_other_cuts()

    test_sets = ["test-clean", "test-other"]
    test_cuts = [test_clean_cuts, test_other_cuts]

    # Decode each test set and write its transcripts / WER statistics.
    for test_set, test_cut in zip(test_sets, test_cuts):
        results_dict = decode_dataset(
            cuts=test_cut,
            model=model,
            params=params,
            sp=sp,
            decoding_graph=decoding_graph,
        )

        save_results(
            params=params,
            test_set_name=test_set,
            results_dict=results_dict,
        )

    logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Seed the torch RNG with a fixed value before running the script.
    torch.manual_seed(20220810)
    main()
|
92
egs/librispeech/ASR/lstm_transducer_stateless/test_model.py
Executable file
92
egs/librispeech/ASR/lstm_transducer_stateless/test_model.py
Executable file
@ -0,0 +1,92 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
To run this file, do:
|
||||
|
||||
cd icefall/egs/librispeech/ASR
|
||||
python ./lstm_transducer_stateless/test_model.py
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from export import (
|
||||
export_decoder_model_jit_trace,
|
||||
export_encoder_model_jit_trace,
|
||||
export_joiner_model_jit_trace,
|
||||
)
|
||||
from lstm import stack_states, unstack_states
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
|
||||
def test_model():
    """Build a transducer model and export its parts with ``torch.jit.trace``.

    The model is created from the default parameters with a few fields
    overridden, converted to its non-scaled form in place, and then the
    encoder, decoder and joiner are each exported into ``params.exp_dir``.
    """
    params = get_params()
    params.vocab_size = 500
    params.blank_id = 0
    params.context_size = 2
    params.unk_id = 2
    params.encoder_dim = 512
    params.rnn_hidden_size = 1024
    params.num_encoder_layers = 12
    params.aux_layer_period = 0
    params.exp_dir = Path("exp_test_model")

    model = get_transducer_model(params)
    model.eval()

    num_param = sum([p.numel() for p in model.parameters()])
    print(f"Number of model parameters: {num_param}")

    convert_scaled_to_non_scaled(model, inplace=True)

    # `os.path` has no `mkdir`; the original call raised AttributeError
    # whenever the directory did not exist yet. `os.makedirs` with
    # `exist_ok=True` also removes the need for a separate existence check.
    os.makedirs(params.exp_dir, exist_ok=True)

    encoder_filename = params.exp_dir / "encoder_jit_trace.pt"
    export_encoder_model_jit_trace(model.encoder, encoder_filename)

    decoder_filename = params.exp_dir / "decoder_jit_trace.pt"
    export_decoder_model_jit_trace(model.decoder, decoder_filename)

    joiner_filename = params.exp_dir / "joiner_jit_trace.pt"
    export_joiner_model_jit_trace(model.joiner, joiner_filename)

    print("The model has been successfully exported using jit.trace.")
|
||||
|
||||
|
||||
def test_states_stack_and_unstack():
    """Check that unstacking and re-stacking LSTM states is a round trip."""
    num_layers, batch_size, hidden_dim, cell_dim = 12, 100, 512, 1024
    hidden = torch.randn(num_layers, batch_size, hidden_dim)
    cell = torch.randn(num_layers, batch_size, cell_dim)

    restored = stack_states(unstack_states((hidden, cell)))

    assert torch.allclose(hidden, restored[0])
    assert torch.allclose(cell, restored[1])
|
||||
|
||||
|
||||
def main():
    """Run every test in this file, in order."""
    for run_test in (test_model, test_states_stack_and_unstack):
        run_test()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running the tests directly as a script.
    main()
|
@ -0,0 +1,257 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
To run this file, do:
|
||||
|
||||
cd icefall/egs/librispeech/ASR
|
||||
python ./lstm_transducer_stateless/test_scaling_converter.py
|
||||
"""
|
||||
|
||||
import copy
|
||||
|
||||
import torch
|
||||
from scaling import (
|
||||
ScaledConv1d,
|
||||
ScaledConv2d,
|
||||
ScaledEmbedding,
|
||||
ScaledLinear,
|
||||
ScaledLSTM,
|
||||
)
|
||||
from scaling_converter import (
|
||||
convert_scaled_to_non_scaled,
|
||||
scaled_conv1d_to_conv1d,
|
||||
scaled_conv2d_to_conv2d,
|
||||
scaled_embedding_to_embedding,
|
||||
scaled_linear_to_linear,
|
||||
scaled_lstm_to_lstm,
|
||||
)
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
|
||||
def get_model():
    """Return a transducer model built from the default parameters.

    A handful of parameters are overridden with fixed values before the
    model is constructed.
    """
    overrides = {
        "vocab_size": 500,
        "blank_id": 0,
        "context_size": 2,
        "unk_id": 2,
        "encoder_dim": 512,
        "rnn_hidden_size": 1024,
        "num_encoder_layers": 12,
        "aux_layer_period": -1,
    }

    params = get_params()
    for field, value in overrides.items():
        setattr(params, field, value)

    return get_transducer_model(params)
|
||||
|
||||
|
||||
def test_scaled_linear_to_linear():
    """Check that a converted ScaledLinear matches the original.

    The comparison is done for eager modules and for their
    ``torch.jit.script`` versions, with and without bias.
    """
    batch, dim_in, dim_out = 5, 10, 20
    for use_bias in (True, False):
        scaled = ScaledLinear(
            in_features=dim_in,
            out_features=dim_out,
            bias=use_bias,
        )
        plain = scaled_linear_to_linear(scaled)
        inputs = torch.rand(batch, dim_in)

        eager_scaled_out = scaled(inputs)
        eager_plain_out = plain(inputs)
        assert torch.allclose(eager_scaled_out, eager_plain_out)

        scripted_scaled = torch.jit.script(scaled)
        scripted_plain = torch.jit.script(plain)

        scripted_scaled_out = scripted_scaled(inputs)
        scripted_plain_out = scripted_plain(inputs)

        assert torch.allclose(scripted_scaled_out, scripted_plain_out)
        assert torch.allclose(eager_scaled_out, scripted_plain_out)
|
||||
|
||||
|
||||
def test_scaled_conv1d_to_conv1d():
    """Check that a converted ScaledConv1d matches the original.

    The comparison is done for eager modules and for their
    ``torch.jit.script`` versions, with and without bias.
    """
    num_in_channels = 3
    for use_bias in (True, False):
        scaled = ScaledConv1d(
            num_in_channels,
            6,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=use_bias,
        )
        plain = scaled_conv1d_to_conv1d(scaled)

        inputs = torch.rand(20, num_in_channels, 10)
        eager_scaled_out = scaled(inputs)
        eager_plain_out = plain(inputs)
        assert torch.allclose(eager_scaled_out, eager_plain_out)

        scripted_scaled = torch.jit.script(scaled)
        scripted_plain = torch.jit.script(plain)

        scripted_scaled_out = scripted_scaled(inputs)
        scripted_plain_out = scripted_plain(inputs)

        assert torch.allclose(scripted_scaled_out, scripted_plain_out)
        assert torch.allclose(eager_scaled_out, scripted_plain_out)
|
||||
|
||||
|
||||
def test_scaled_conv2d_to_conv2d():
    """Check that a converted ScaledConv2d matches the original.

    The comparison is done for eager modules and for their
    ``torch.jit.script`` versions, with and without bias.
    """
    num_in_channels = 1
    for use_bias in (True, False):
        scaled = ScaledConv2d(
            in_channels=num_in_channels,
            out_channels=3,
            kernel_size=3,
            padding=1,
            bias=use_bias,
        )
        plain = scaled_conv2d_to_conv2d(scaled)

        inputs = torch.rand(20, num_in_channels, 10, 20)
        eager_scaled_out = scaled(inputs)
        eager_plain_out = plain(inputs)
        assert torch.allclose(eager_scaled_out, eager_plain_out)

        scripted_scaled = torch.jit.script(scaled)
        scripted_plain = torch.jit.script(plain)

        scripted_scaled_out = scripted_scaled(inputs)
        scripted_plain_out = scripted_plain(inputs)

        assert torch.allclose(scripted_scaled_out, scripted_plain_out)
        assert torch.allclose(eager_scaled_out, scripted_plain_out)
|
||||
|
||||
|
||||
def test_scaled_embedding_to_embedding():
    """Check that a converted ScaledEmbedding returns identical outputs.

    Index tensors of several lengths are looked up in both modules and the
    results are required to be exactly equal.
    """
    scaled = ScaledEmbedding(
        num_embeddings=500,
        embedding_dim=10,
        padding_idx=0,
    )
    plain = scaled_embedding_to_embedding(scaled)

    for num_indexes in (10, 100, 300, 500, 800, 1000):
        indexes = torch.randint(low=0, high=500, size=(num_indexes,))
        expected = scaled(indexes)
        actual = plain(indexes)
        assert torch.equal(expected, actual)
|
||||
|
||||
|
||||
def test_scaled_lstm_to_lstm():
    """Check that a converted ScaledLSTM matches the original.

    Outputs and final states are compared for the eager converted LSTM and
    for its ``torch.jit.trace`` version, with and without bias, and with
    and without a projection layer.
    """
    feature_dim = 512
    num_seqs = 20
    for use_bias in (True, False):
        for cell_dim in (512, 1024):
            # A projection is only used when the hidden size differs from
            # the input size.
            projection = 0 if cell_dim == feature_dim else feature_dim
            scaled = ScaledLSTM(
                input_size=feature_dim,
                hidden_size=cell_dim,
                num_layers=1,
                bias=use_bias,
                proj_size=projection,
            )
            plain = scaled_lstm_to_lstm(scaled)

            inputs = torch.rand(200, num_seqs, feature_dim)
            init_hidden = torch.randn(1, num_seqs, feature_dim)
            init_cell = torch.randn(1, num_seqs, cell_dim)
            init_states = (init_hidden, init_cell)

            out_scaled, (hid_scaled, cell_scaled) = scaled(inputs, init_states)
            out_plain, (hid_plain, cell_plain) = plain(inputs, init_states)
            assert torch.allclose(out_scaled, out_plain)
            assert torch.allclose(hid_scaled, hid_plain)
            assert torch.allclose(cell_scaled, cell_plain)

            traced = torch.jit.trace(plain, (inputs, init_states))
            out_traced, (hid_traced, cell_traced) = traced(inputs, init_states)
            assert torch.allclose(out_scaled, out_traced)
            assert torch.allclose(hid_scaled, hid_traced)
            assert torch.allclose(cell_scaled, cell_traced)
|
||||
|
||||
|
||||
def test_convert_scaled_to_non_scaled():
    """Check that a converted model produces the same outputs as the original.

    The encoder, decoder, simple projections and joiner of the converted
    model are each compared against a deep copy of the original model, for
    both the in-place and the out-of-place conversion.
    """
    for inplace in [False, True]:
        model = get_model()
        model.eval()

        # Keep an untouched reference copy, since the in-place conversion
        # modifies `model` itself.
        orig_model = copy.deepcopy(model)

        converted_model = convert_scaled_to_non_scaled(model, inplace=inplace)

        model = orig_model

        # test encoder
        N = 2
        T = 100
        vocab_size = model.decoder.vocab_size

        x = torch.randn(N, T, 80, dtype=torch.float32)
        x_lens = torch.full((N,), x.size(1))

        e1, e1_lens, _ = model.encoder(x, x_lens)
        e2, e2_lens, _ = converted_model.encoder(x, x_lens)

        assert torch.all(torch.eq(e1_lens, e2_lens))
        assert torch.allclose(e1, e2), (e1 - e2).abs().max()

        # test decoder
        U = 50
        y = torch.randint(low=1, high=vocab_size - 1, size=(N, U))

        d1 = model.decoder(y)
        # The original compared `model.decoder` against itself, so the
        # converted decoder was never exercised.
        d2 = converted_model.decoder(y)

        assert torch.allclose(d1, d2)

        # test simple projection
        lm1 = model.simple_lm_proj(d1)
        am1 = model.simple_am_proj(e1)

        lm2 = converted_model.simple_lm_proj(d2)
        am2 = converted_model.simple_am_proj(e2)

        assert torch.allclose(lm1, lm2)
        assert torch.allclose(am1, am2)

        # test joiner
        e = torch.rand(2, 3, 4, 512)
        d = torch.rand(2, 3, 4, 512)

        j1 = model.joiner(e, d)
        j2 = converted_model.joiner(e, d)
        assert torch.allclose(j1, j2)
|
||||
|
||||
|
||||
@torch.no_grad()
def main():
    """Run every test in this file, in order, with gradients disabled."""
    all_tests = (
        test_scaled_linear_to_linear,
        test_scaled_conv1d_to_conv1d,
        test_scaled_conv2d_to_conv2d,
        test_scaled_embedding_to_embedding,
        test_scaled_lstm_to_lstm,
        test_convert_scaled_to_non_scaled,
    )
    for run_test in all_tests:
        run_test()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Seed the torch RNG with a fixed value before running the tests.
    torch.manual_seed(20220730)
    main()
|
1119
egs/librispeech/ASR/lstm_transducer_stateless/train.py
Executable file
1119
egs/librispeech/ASR/lstm_transducer_stateless/train.py
Executable file
File diff suppressed because it is too large
Load Diff
@ -391,6 +391,7 @@ def decode_dataset(
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
@ -403,9 +404,9 @@ def decode_dataset(
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for hyp_words, ref_text in zip(hyps, texts):
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((ref_words, hyp_words))
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
@ -430,6 +431,7 @@ def save_results(
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
@ -612,6 +614,8 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
|
||||
test_clean_cuts = librispeech.test_clean_cuts()
|
||||
|
@ -603,6 +603,15 @@ def compute_loss(
|
||||
(feature_lens // params.subsampling_factor).sum().item()
|
||||
)
|
||||
|
||||
# `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa
|
||||
info["utterances"] = feature.size(0)
|
||||
# averaged input duration in frames over utterances
|
||||
info["utt_duration"] = feature_lens.sum().item()
|
||||
# averaged padding proportion over utterances
|
||||
info["utt_pad_proportion"] = (
|
||||
((feature.size(1) - feature_lens) / feature.size(1)).sum().item()
|
||||
)
|
||||
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
info["simple_loss"] = simple_loss.detach().cpu().item()
|
||||
|
@ -751,7 +751,7 @@ class HypothesisList(object):
|
||||
return ", ".join(s)
|
||||
|
||||
|
||||
def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
|
||||
def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
|
||||
"""Return a ragged shape with axes [utt][num_hyps].
|
||||
|
||||
Args:
|
||||
@ -847,7 +847,7 @@ def modified_beam_search(
|
||||
finalized_B = B[batch_size:] + finalized_B
|
||||
B = B[:batch_size]
|
||||
|
||||
hyps_shape = _get_hyps_shape(B).to(device)
|
||||
hyps_shape = get_hyps_shape(B).to(device)
|
||||
|
||||
A = [list(b) for b in B]
|
||||
B = [HypothesisList() for _ in range(batch_size)]
|
||||
|
@ -551,6 +551,7 @@ def decode_dataset(
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
@ -564,9 +565,9 @@ def decode_dataset(
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for hyp_words, ref_text in zip(hyps, texts):
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((ref_words, hyp_words))
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
@ -591,6 +592,7 @@ def save_results(
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
@ -631,6 +633,8 @@ def main():
|
||||
LibriSpeechAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
@ -754,6 +758,8 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
|
||||
test_clean_cuts = librispeech.test_clean_cuts()
|
||||
|
@ -19,6 +19,7 @@ from typing import List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import torch
|
||||
from beam_search import Hypothesis, HypothesisList
|
||||
|
||||
from icefall.utils import AttributeDict
|
||||
|
||||
@ -27,6 +28,7 @@ class DecodeStream(object):
|
||||
def __init__(
|
||||
self,
|
||||
params: AttributeDict,
|
||||
cut_id: str,
|
||||
initial_states: List[torch.Tensor],
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
@ -42,10 +44,12 @@ class DecodeStream(object):
|
||||
device:
|
||||
The device to run this stream.
|
||||
"""
|
||||
if decoding_graph is not None:
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
assert decoding_graph is not None
|
||||
assert device == decoding_graph.device
|
||||
|
||||
self.params = params
|
||||
self.cut_id = cut_id
|
||||
self.LOG_EPS = math.log(1e-10)
|
||||
|
||||
self.states = initial_states
|
||||
@ -77,21 +81,33 @@ class DecodeStream(object):
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
self.hyp = [params.blank_id] * params.context_size
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
self.hyps = HypothesisList()
|
||||
self.hyps.add(
|
||||
Hypothesis(
|
||||
ys=[params.blank_id] * params.context_size,
|
||||
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||
)
|
||||
)
|
||||
elif params.decoding_method == "fast_beam_search":
|
||||
# The rnnt_decoding_stream for fast_beam_search.
|
||||
self.rnnt_decoding_stream: k2.RnntDecodingStream = (
|
||||
k2.RnntDecodingStream(decoding_graph)
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
False
|
||||
), f"Decoding method :{params.decoding_method} do not support."
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
|
||||
@property
|
||||
def done(self) -> bool:
|
||||
"""Return True if all the features are processed."""
|
||||
return self._done
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self.cut_id
|
||||
|
||||
def set_features(
|
||||
self,
|
||||
features: torch.Tensor,
|
||||
@ -124,3 +140,14 @@ class DecodeStream(object):
|
||||
self._done = True
|
||||
|
||||
return ret_features, ret_length
|
||||
|
||||
def decoding_result(self) -> List[int]:
|
||||
"""Obtain current decoding result."""
|
||||
if self.params.decoding_method == "greedy_search":
|
||||
return self.hyp[self.params.context_size :] # noqa
|
||||
elif self.params.decoding_method == "modified_beam_search":
|
||||
best_hyp = self.hyps.get_most_probable(length_norm=True)
|
||||
return best_hyp.ys[self.params.context_size :] # noqa
|
||||
else:
|
||||
assert self.params.decoding_method == "fast_beam_search"
|
||||
return self.hyp
|
||||
|
@ -15,6 +15,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -66,7 +68,8 @@ class Transducer(nn.Module):
|
||||
prune_range: int = 5,
|
||||
am_scale: float = 0.0,
|
||||
lm_scale: float = 0.0,
|
||||
) -> torch.Tensor:
|
||||
reduction: str = "sum",
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
@ -86,6 +89,10 @@ class Transducer(nn.Module):
|
||||
lm_scale:
|
||||
The scale to smooth the loss with lm (output of predictor network)
|
||||
part
|
||||
reduction:
|
||||
"sum" to sum the losses over all utterances in the batch.
|
||||
"none" to return the loss in a 1-D tensor for each utterance
|
||||
in the batch.
|
||||
Returns:
|
||||
Return the transducer loss.
|
||||
|
||||
@ -95,6 +102,7 @@ class Transducer(nn.Module):
|
||||
lm_scale * lm_probs + am_scale * am_probs +
|
||||
(1-lm_scale-am_scale) * combined_probs
|
||||
"""
|
||||
assert reduction in ("sum", "none"), reduction
|
||||
assert x.ndim == 3, x.shape
|
||||
assert x_lens.ndim == 1, x_lens.shape
|
||||
assert y.num_axes == 2, y.num_axes
|
||||
@ -136,7 +144,7 @@ class Transducer(nn.Module):
|
||||
lm_only_scale=lm_scale,
|
||||
am_only_scale=am_scale,
|
||||
boundary=boundary,
|
||||
reduction="sum",
|
||||
reduction=reduction,
|
||||
return_grad=True,
|
||||
)
|
||||
|
||||
@ -163,7 +171,7 @@ class Transducer(nn.Module):
|
||||
ranges=ranges,
|
||||
termination_symbol=blank_id,
|
||||
boundary=boundary,
|
||||
reduction="sum",
|
||||
reduction=reduction,
|
||||
)
|
||||
|
||||
return (simple_loss, pruned_loss)
|
||||
|
@ -0,0 +1,280 @@
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Wei Kang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import warnings
|
||||
from typing import List
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from beam_search import Hypothesis, HypothesisList, get_hyps_shape
|
||||
from decode_stream import DecodeStream
|
||||
|
||||
from icefall.decode import one_best_decoding
|
||||
from icefall.utils import get_texts
|
||||
|
||||
|
||||
def greedy_search(
    model: nn.Module,
    encoder_out: torch.Tensor,
    streams: List[DecodeStream],
) -> None:
    """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.

    The result is not returned: every non-blank symbol that is emitted is
    appended in place to the ``hyp`` list of the corresponding stream.

    Args:
      model:
        The transducer model. Its ``decoder`` must provide ``blank_id`` and
        ``context_size``.
      encoder_out:
        Output from the encoder. Its shape is (N, T, C), where N >= 1.
      streams:
        A list of Stream objects, one per row of ``encoder_out``.
    """
    assert len(streams) == encoder_out.size(0)
    assert encoder_out.ndim == 3

    blank_id = model.decoder.blank_id
    context_size = model.decoder.context_size
    # `nn.Module` does not define a `device` attribute, so take the device
    # from the parameters (same approach as `modified_beam_search`).
    device = next(model.parameters()).device
    T = encoder_out.size(1)

    def _decoder_output() -> torch.Tensor:
        # Run the decoder on the last `context_size` symbols of every stream.
        decoder_input = torch.tensor(
            [stream.hyp[-context_size:] for stream in streams],
            device=device,
            dtype=torch.int64,
        )
        return model.decoder(decoder_input, need_pad=False)

    # decoder_out is of shape (N, 1, decoder_out_dim)
    decoder_out = _decoder_output()

    for t in range(T):
        # current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
        current_encoder_out = encoder_out[:, t : t + 1, :]  # noqa

        logits = model.joiner(
            current_encoder_out.unsqueeze(2),
            decoder_out.unsqueeze(1),
        )
        # logits'shape (batch_size, vocab_size)
        logits = logits.squeeze(1).squeeze(1)

        assert logits.ndim == 2, logits.shape
        y = logits.argmax(dim=1).tolist()
        emitted = False
        for i, v in enumerate(y):
            if v != blank_id:
                streams[i].hyp.append(v)
                emitted = True
        if emitted:
            # The decoder output only changes when at least one stream got
            # a new symbol, so it is recomputed only in that case.
            decoder_out = _decoder_output()
|
||||
|
||||
|
||||
def modified_beam_search(
    model: nn.Module,
    encoder_out: torch.Tensor,
    streams: List[DecodeStream],
    num_active_paths: int = 4,
) -> None:
    """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.

    The result is not returned: after all frames are processed, the
    ``hyps`` attribute of every stream is replaced with its updated
    hypothesis list.

    Args:
      model:
        The RNN-T model.
      encoder_out:
        A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of
        the encoder model.
      streams:
        A list of stream objects.
      num_active_paths:
        Number of active paths during the beam search.
    """
    assert encoder_out.ndim == 3, encoder_out.shape
    assert len(streams) == encoder_out.size(0)

    blank_id = model.decoder.blank_id
    context_size = model.decoder.context_size
    device = next(model.parameters()).device
    batch_size = len(streams)
    T = encoder_out.size(1)

    # Start from the hypotheses currently stored in the streams.
    B = [stream.hyps for stream in streams]

    for t in range(T):
        current_encoder_out = encoder_out[:, t].unsqueeze(1).unsqueeze(1)
        # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)

        hyps_shape = get_hyps_shape(B).to(device)

        # A holds the hypotheses to expand in this frame; B collects the
        # expanded hypotheses.
        A = [list(b) for b in B]
        B = [HypothesisList() for _ in range(batch_size)]

        ys_log_probs = torch.stack(
            [hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0
        )  # (num_hyps, 1)

        decoder_input = torch.tensor(
            [hyp.ys[-context_size:] for hyps in A for hyp in hyps],
            device=device,
            dtype=torch.int64,
        )  # (num_hyps, context_size)

        decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
        # decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim)

        # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
        # as index, so we use `to(torch.int64)` below.
        # Repeat the encoder frame of each stream once per hypothesis of
        # that stream, so that it lines up with decoder_out.
        current_encoder_out = torch.index_select(
            current_encoder_out,
            dim=0,
            index=hyps_shape.row_ids(1).to(torch.int64),
        )  # (num_hyps, 1, 1, encoder_out_dim)

        logits = model.joiner(current_encoder_out, decoder_out)
        # logits is of shape (num_hyps, 1, 1, vocab_size)

        logits = logits.squeeze(1).squeeze(1)

        log_probs = logits.log_softmax(dim=-1)  # (num_hyps, vocab_size)

        # Add the score of each hypothesis to the scores of its extensions.
        log_probs.add_(ys_log_probs)

        vocab_size = log_probs.size(-1)

        log_probs = log_probs.reshape(-1)

        # Group the flattened scores by stream: each stream owns
        # (number of its hypotheses) * vocab_size consecutive entries.
        row_splits = hyps_shape.row_splits(1) * vocab_size
        log_probs_shape = k2.ragged.create_ragged_shape2(
            row_splits=row_splits, cached_tot_size=log_probs.numel()
        )
        ragged_log_probs = k2.RaggedTensor(
            shape=log_probs_shape, value=log_probs
        )

        for i in range(batch_size):
            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(
                num_active_paths
            )

            # A flat index encodes (hypothesis index, token index) as
            # hypothesis_index * vocab_size + token_index.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
                topk_token_indexes = (topk_indexes % vocab_size).tolist()

            for k in range(len(topk_hyp_indexes)):
                hyp_idx = topk_hyp_indexes[k]
                hyp = A[i][hyp_idx]

                # Copy the symbol list so the parent hypothesis is unchanged.
                new_ys = hyp.ys[:]
                new_token = topk_token_indexes[k]
                if new_token != blank_id:
                    new_ys.append(new_token)

                new_log_prob = topk_log_probs[k]
                new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
                B[i].add(new_hyp)

    # Write the final hypothesis lists back to the streams.
    for i in range(batch_size):
        streams[i].hyps = B[i]
|
||||
|
||||
|
||||
def fast_beam_search_one_best(
    model: nn.Module,
    encoder_out: torch.Tensor,
    processed_lens: torch.Tensor,
    streams: List[DecodeStream],
    beam: float,
    max_states: int,
    max_contexts: int,
) -> None:
    """Decode one chunk with FSA-based beam search and keep the best path.

    The maximum number of symbols per frame is limited to 1. A lattice is
    generated frame by frame, the shortest path through it is taken as the
    recognition result, and that result is written to ``streams[i].hyp``.
    Nothing is returned.

    Args:
      model:
        An instance of `Transducer`.
      encoder_out:
        A tensor of shape (N, T, C) from the encoder.
      processed_lens:
        A tensor of shape (N,) containing the number of processed frames
        in `encoder_out` before padding.
      streams:
        A list of stream objects; its length must equal N.
      beam:
        Beam value, similar to the beam used in Kaldi.
      max_states:
        Max states per stream per frame.
      max_contexts:
        Max contexts per stream per frame.
    """
    assert encoder_out.ndim == 3
    num_streams, num_frames, _ = encoder_out.shape
    assert num_streams == len(streams)

    config = k2.RnntDecodingConfig(
        vocab_size=model.decoder.vocab_size,
        decoder_history_len=model.decoder.context_size,
        beam=beam,
        max_contexts=max_contexts,
        max_states=max_states,
    )
    # Bundle the per-utterance k2 streams so they advance in lock-step.
    decoding_streams = k2.RnntDecodingStreams(
        [streams[i].rnnt_decoding_stream for i in range(num_streams)],
        config,
    )

    for t in range(num_frames):
        # shape is a RaggedShape of shape (B, context)
        # contexts is a Tensor of shape (shape.NumElements(), context_size)
        shape, contexts = decoding_streams.get_contexts()

        # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
        # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
        decoder_out = model.decoder(contexts.to(torch.int64), need_pad=False)

        # Repeat frame t of every stream once per active context of that
        # stream; the result is of shape (shape.NumElements(), 1, joiner_dim)
        frame = encoder_out[:, t : t + 1, :]  # noqa
        context_to_stream = shape.row_ids(1).to(torch.int64)
        current_encoder_out = torch.index_select(frame, 0, context_to_stream)

        logits = model.joiner(
            current_encoder_out.unsqueeze(2),
            decoder_out.unsqueeze(1),
        )
        log_probs = logits.squeeze(1).squeeze(1).log_softmax(dim=-1)
        decoding_streams.advance(log_probs)

    decoding_streams.terminate_and_flush_to_streams()

    lattice = decoding_streams.format_output(processed_lens.tolist())
    hyp_tokens = get_texts(one_best_decoding(lattice))

    for i in range(num_streams):
        streams[i].hyp = hyp_tokens[i]
|
@ -17,13 +17,13 @@
|
||||
|
||||
"""
|
||||
Usage:
|
||||
./pruned_transducer_stateless2/streaming_decode.py \
|
||||
./pruned_transducer_stateless/streaming_decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--decode-chunk-size 8 \
|
||||
--left-context 32 \
|
||||
--right-context 0 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--exp-dir ./pruned_transducer_stateless/exp \
|
||||
--decoding_method greedy_search \
|
||||
--num-decode-streams 1000
|
||||
"""
|
||||
@ -43,6 +43,11 @@ from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from decode_stream import DecodeStream
|
||||
from kaldifeat import Fbank, FbankOptions
|
||||
from lhotse import CutSet
|
||||
from streaming_beam_search import (
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
modified_beam_search,
|
||||
)
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
@ -51,10 +56,8 @@ from icefall.checkpoint import (
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.decode import one_best_decoding
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
get_texts,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
write_error_stats,
|
||||
@ -114,10 +117,21 @@ def get_parser():
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Support only greedy_search and fast_beam_search now.
|
||||
help="""Supported decoding methods are:
|
||||
greedy_search
|
||||
modified_beam_search
|
||||
fast_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-active-paths",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An interger indicating how many candidates we will keep for each
|
||||
frame. Used only when --decoding-method is modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
@ -185,103 +199,6 @@ def get_parser():
|
||||
return parser
|
||||
|
||||
|
||||
def greedy_search(
    model: nn.Module,
    encoder_out: torch.Tensor,
    streams: List[DecodeStream],
) -> List[List[int]]:
    """Greedy search over one chunk for a batch of streams.

    For every frame the arg-max token of each stream is taken; non-blank
    tokens are appended in place to that stream's ``hyp`` list. The decoder
    is re-run only after a frame on which at least one stream emitted.

    Args:
      model:
        The transducer model. It must expose ``decoder`` (with ``blank_id``
        and ``context_size``), ``joiner`` and ``device``.
      encoder_out:
        A 3-D tensor whose first dimension equals ``len(streams)`` and whose
        second dimension is the number of frames to decode.
      streams:
        One stream per row of ``encoder_out``. Each ``hyp`` is used as the
        decoder history.
        NOTE(review): assumes every ``hyp`` already holds at least
        ``context_size`` tokens -- confirm against the caller.

    Returns:
      A list with ``streams[i].hyp`` for every stream. These are the very
      same list objects held by the streams, not copies.
    """
    assert len(streams) == encoder_out.size(0)
    assert encoder_out.ndim == 3

    blank_id = model.decoder.blank_id
    context_size = model.decoder.context_size
    device = model.device
    num_frames = encoder_out.size(1)

    def run_decoder() -> torch.Tensor:
        # Feed the last `context_size` tokens of every stream to the decoder.
        history = torch.tensor(
            [stream.hyp[-context_size:] for stream in streams],
            device=device,
            dtype=torch.int64,
        )
        return model.decoder(history, need_pad=False)

    # decoder_out is of shape (N, decoder_out_dim)
    decoder_out = run_decoder()

    for t in range(num_frames):
        # frame's shape: (batch_size, 1, encoder_out_dim)
        frame = encoder_out[:, t : t + 1, :]  # noqa

        logits = model.joiner(frame.unsqueeze(2), decoder_out.unsqueeze(1))
        # logits' shape: (batch_size, vocab_size)
        logits = logits.squeeze(1).squeeze(1)
        assert logits.ndim == 2, logits.shape

        best_tokens = logits.argmax(dim=1).tolist()
        emissions = [
            (i, token)
            for i, token in enumerate(best_tokens)
            if token != blank_id
        ]
        for i, token in emissions:
            streams[i].hyp.append(token)

        if emissions:
            # At least one history changed, so refresh the decoder output.
            decoder_out = run_decoder()

    return [stream.hyp for stream in streams]
|
||||
|
||||
|
||||
def fast_beam_search(
    model: nn.Module,
    encoder_out: torch.Tensor,
    processed_lens: torch.Tensor,
    decoding_streams: k2.RnntDecodingStreams,
) -> List[List[int]]:
    """Advance `decoding_streams` over one chunk and return the 1-best tokens.

    Args:
      model:
        The transducer model; its ``decoder`` and ``joiner`` are used.
      encoder_out:
        A 3-D tensor (unpacked as B, T, C) from the encoder.
      processed_lens:
        A tensor whose ``tolist()`` is passed to
        ``decoding_streams.format_output`` to build the lattice.
      decoding_streams:
        The k2 decoding streams to advance, one frame at a time.

    Returns:
      The token sequences that ``get_texts`` extracts from the best path
      of the lattice.
    """
    _, num_frames, _ = encoder_out.shape

    for t in range(num_frames):
        # shape is a RaggedShape of shape (B, context)
        # contexts is a Tensor of shape (shape.NumElements(), context_size)
        shape, contexts = decoding_streams.get_contexts()

        # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
        # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
        decoder_out = model.decoder(contexts.to(torch.int64), need_pad=False)

        # Repeat frame t of every stream once per active context of that
        # stream; the result is of shape (shape.NumElements(), 1, joiner_dim)
        frame = encoder_out[:, t : t + 1, :]  # noqa
        context_to_stream = shape.row_ids(1).to(torch.int64)
        current_encoder_out = torch.index_select(frame, 0, context_to_stream)

        logits = model.joiner(
            current_encoder_out.unsqueeze(2),
            decoder_out.unsqueeze(1),
        )
        log_probs = logits.squeeze(1).squeeze(1).log_softmax(dim=-1)
        decoding_streams.advance(log_probs)

    decoding_streams.terminate_and_flush_to_streams()

    lattice = decoding_streams.format_output(processed_lens.tolist())
    return get_texts(one_best_decoding(lattice))
|
||||
|
||||
|
||||
def decode_one_chunk(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
@ -305,8 +222,6 @@ def decode_one_chunk(
|
||||
features = []
|
||||
feature_lens = []
|
||||
states = []
|
||||
|
||||
rnnt_stream_list = []
|
||||
processed_lens = []
|
||||
|
||||
for stream in decode_streams:
|
||||
@ -317,8 +232,6 @@ def decode_one_chunk(
|
||||
feature_lens.append(feat_len)
|
||||
states.append(stream.states)
|
||||
processed_lens.append(stream.done_frames)
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
rnnt_stream_list.append(stream.rnnt_decoding_stream)
|
||||
|
||||
feature_lens = torch.tensor(feature_lens, device=device)
|
||||
features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
|
||||
@ -330,19 +243,13 @@ def decode_one_chunk(
|
||||
# frames.
|
||||
tail_length = 7 + (2 + params.right_context) * params.subsampling_factor
|
||||
if features.size(1) < tail_length:
|
||||
feature_lens += tail_length - features.size(1)
|
||||
features = torch.cat(
|
||||
[
|
||||
features,
|
||||
torch.tensor(
|
||||
LOG_EPS, dtype=features.dtype, device=device
|
||||
).expand(
|
||||
features.size(0),
|
||||
tail_length - features.size(1),
|
||||
features.size(2),
|
||||
),
|
||||
],
|
||||
dim=1,
|
||||
pad_length = tail_length - features.size(1)
|
||||
feature_lens += pad_length
|
||||
features = torch.nn.functional.pad(
|
||||
features,
|
||||
(0, 0, 0, pad_length),
|
||||
mode="constant",
|
||||
value=LOG_EPS,
|
||||
)
|
||||
|
||||
states = [
|
||||
@ -362,22 +269,31 @@ def decode_one_chunk(
|
||||
)
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
hyp_tokens = greedy_search(model, encoder_out, decode_streams)
|
||||
elif params.decoding_method == "fast_beam_search":
|
||||
config = k2.RnntDecodingConfig(
|
||||
vocab_size=params.vocab_size,
|
||||
decoder_history_len=params.context_size,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
greedy_search(
|
||||
model=model, encoder_out=encoder_out, streams=decode_streams
|
||||
)
|
||||
decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config)
|
||||
elif params.decoding_method == "fast_beam_search":
|
||||
processed_lens = processed_lens + encoder_out_lens
|
||||
hyp_tokens = fast_beam_search(
|
||||
model, encoder_out, processed_lens, decoding_streams
|
||||
fast_beam_search_one_best(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
processed_lens=processed_lens,
|
||||
streams=decode_streams,
|
||||
beam=params.beam,
|
||||
max_states=params.max_states,
|
||||
max_contexts=params.max_contexts,
|
||||
)
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
modified_beam_search(
|
||||
model=model,
|
||||
streams=decode_streams,
|
||||
encoder_out=encoder_out,
|
||||
num_active_paths=params.num_active_paths,
|
||||
)
|
||||
else:
|
||||
assert False
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
|
||||
states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)]
|
||||
|
||||
@ -385,8 +301,6 @@ def decode_one_chunk(
|
||||
for i in range(len(decode_streams)):
|
||||
decode_streams[i].states = [states[0][i], states[1][i]]
|
||||
decode_streams[i].done_frames += encoder_out_lens[i]
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
decode_streams[i].hyp = hyp_tokens[i]
|
||||
if decode_streams[i].done:
|
||||
finished_streams.append(i)
|
||||
|
||||
@ -442,6 +356,7 @@ def decode_dataset(
|
||||
# each utterance has a DecodeStream.
|
||||
decode_stream = DecodeStream(
|
||||
params=params,
|
||||
cut_id=cut.id,
|
||||
initial_states=initial_states,
|
||||
decoding_graph=decoding_graph,
|
||||
device=device,
|
||||
@ -469,13 +384,11 @@ def decode_dataset(
|
||||
params=params, model=model, decode_streams=decode_streams
|
||||
)
|
||||
for i in sorted(finished_streams, reverse=True):
|
||||
hyp = decode_streams[i].hyp
|
||||
if params.decoding_method == "greedy_search":
|
||||
hyp = hyp[params.context_size :] # noqa
|
||||
decode_results.append(
|
||||
(
|
||||
decode_streams[i].id,
|
||||
decode_streams[i].ground_truth.split(),
|
||||
sp.decode(hyp).split(),
|
||||
sp.decode(decode_streams[i].decoding_result()).split(),
|
||||
)
|
||||
)
|
||||
del decode_streams[i]
|
||||
@ -489,24 +402,30 @@ def decode_dataset(
|
||||
params=params, model=model, decode_streams=decode_streams
|
||||
)
|
||||
for i in sorted(finished_streams, reverse=True):
|
||||
hyp = decode_streams[i].hyp
|
||||
if params.decoding_method == "greedy_search":
|
||||
hyp = hyp[params.context_size :] # noqa
|
||||
decode_results.append(
|
||||
(
|
||||
decode_streams[i].id,
|
||||
decode_streams[i].ground_truth.split(),
|
||||
sp.decode(hyp).split(),
|
||||
sp.decode(decode_streams[i].decoding_result()).split(),
|
||||
)
|
||||
)
|
||||
del decode_streams[i]
|
||||
|
||||
key = "greedy_search"
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
if params.decoding_method == "greedy_search":
|
||||
key = "greedy_search"
|
||||
elif params.decoding_method == "fast_beam_search":
|
||||
key = (
|
||||
f"beam_{params.beam}_"
|
||||
f"max_contexts_{params.max_contexts}_"
|
||||
f"max_states_{params.max_states}"
|
||||
)
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
key = f"num_active_paths_{params.num_active_paths}"
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
|
||||
return {key: decode_results}
|
||||
|
||||
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user