mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-10 10:32:17 +00:00
Merge remote-tracking branch 'k2-fsa/master'
This commit is contained in:
commit
09f3e573b2
7
.flake8
7
.flake8
@ -4,12 +4,15 @@ statistics=true
|
||||
max-line-length = 80
|
||||
per-file-ignores =
|
||||
# line too long
|
||||
icefall/diagnostics.py: E501
|
||||
icefall/diagnostics.py: E501,
|
||||
egs/*/ASR/*/conformer.py: E501,
|
||||
egs/*/ASR/pruned_transducer_stateless*/*.py: E501,
|
||||
egs/*/ASR/*/optim.py: E501,
|
||||
egs/*/ASR/*/scaling.py: E501,
|
||||
egs/librispeech/ASR/conv_emformer_transducer_stateless/*.py: E501, E203
|
||||
egs/librispeech/ASR/lstm_transducer_stateless/*.py: E501, E203
|
||||
egs/librispeech/ASR/conv_emformer_transducer_stateless*/*.py: E501, E203
|
||||
egs/librispeech/ASR/conformer_ctc2/*py: E501,
|
||||
egs/librispeech/ASR/RESULTS.md: E999,
|
||||
|
||||
# invalid escape sequence (cause by tex formular), W605
|
||||
icefall/utils.py: E501, W605
|
||||
|
@ -22,8 +22,80 @@ ls -lh $repo/test_wavs/*.wav
|
||||
|
||||
pushd $repo/exp
|
||||
ln -s pretrained-iter-1224000-avg-14.pt pretrained.pt
|
||||
ln -s pretrained-iter-1224000-avg-14.pt epoch-99.pt
|
||||
popd
|
||||
|
||||
log "Test exporting to ONNX format"
|
||||
|
||||
./pruned_transducer_stateless3/export.py \
|
||||
--exp-dir $repo/exp \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--onnx 1
|
||||
|
||||
log "Export to torchscript model"
|
||||
./pruned_transducer_stateless3/export.py \
|
||||
--exp-dir $repo/exp \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--jit 1
|
||||
|
||||
./pruned_transducer_stateless3/export.py \
|
||||
--exp-dir $repo/exp \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--jit-trace 1
|
||||
|
||||
ls -lh $repo/exp/*.onnx
|
||||
ls -lh $repo/exp/*.pt
|
||||
|
||||
log "Decode with ONNX models"
|
||||
|
||||
./pruned_transducer_stateless3/onnx_check.py \
|
||||
--jit-filename $repo/exp/cpu_jit.pt \
|
||||
--onnx-encoder-filename $repo/exp/encoder.onnx \
|
||||
--onnx-decoder-filename $repo/exp/decoder.onnx \
|
||||
--onnx-joiner-filename $repo/exp/joiner.onnx
|
||||
|
||||
./pruned_transducer_stateless3/onnx_check_all_in_one.py \
|
||||
--jit-filename $repo/exp/cpu_jit.pt \
|
||||
--onnx-all-in-one-filename $repo/exp/all_in_one.onnx
|
||||
|
||||
./pruned_transducer_stateless3/onnx_pretrained.py \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--encoder-model-filename $repo/exp/encoder.onnx \
|
||||
--decoder-model-filename $repo/exp/decoder.onnx \
|
||||
--joiner-model-filename $repo/exp/joiner.onnx \
|
||||
$repo/test_wavs/1089-134686-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0002.wav
|
||||
|
||||
log "Decode with models exported by torch.jit.trace()"
|
||||
|
||||
./pruned_transducer_stateless3/jit_pretrained.py \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--encoder-model-filename $repo/exp/encoder_jit_trace.pt \
|
||||
--decoder-model-filename $repo/exp/decoder_jit_trace.pt \
|
||||
--joiner-model-filename $repo/exp/joiner_jit_trace.pt \
|
||||
$repo/test_wavs/1089-134686-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0002.wav
|
||||
|
||||
log "Decode with models exported by torch.jit.script()"
|
||||
|
||||
./pruned_transducer_stateless3/jit_pretrained.py \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--encoder-model-filename $repo/exp/encoder_jit_script.pt \
|
||||
--decoder-model-filename $repo/exp/decoder_jit_script.pt \
|
||||
--joiner-model-filename $repo/exp/joiner_jit_script.pt \
|
||||
$repo/test_wavs/1089-134686-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0002.wav
|
||||
|
||||
|
||||
for sym in 1 2 3; do
|
||||
log "Greedy search with --max-sym-per-frame $sym"
|
||||
|
||||
|
@ -70,7 +70,7 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" ==
|
||||
max_duration=100
|
||||
|
||||
for method in greedy_search fast_beam_search modified_beam_search; do
|
||||
log "Decoding with $method"
|
||||
log "Simulate streaming decoding with $method"
|
||||
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--decoding-method $method \
|
||||
@ -82,5 +82,19 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" ==
|
||||
--causal-convolution 1
|
||||
done
|
||||
|
||||
for method in greedy_search fast_beam_search modified_beam_search; do
|
||||
log "Real streaming decoding with $method"
|
||||
|
||||
./pruned_transducer_stateless2/streaming_decode.py \
|
||||
--decoding-method $method \
|
||||
--epoch 999 \
|
||||
--avg 1 \
|
||||
--num-decode-streams 100 \
|
||||
--exp-dir pruned_transducer_stateless2/exp \
|
||||
--left-context 32 \
|
||||
--decode-chunk-size 8 \
|
||||
--right-context 0
|
||||
done
|
||||
|
||||
rm pruned_transducer_stateless2/exp/*.pt
|
||||
fi
|
||||
|
65
.github/workflows/build-doc.yml
vendored
Normal file
65
.github/workflows/build-doc.yml
vendored
Normal file
@ -0,0 +1,65 @@
|
||||
# Copyright 2022 Xiaomi Corp. (author: Fangjun Kuang)
|
||||
|
||||
# See ../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# refer to https://github.com/actions/starter-workflows/pull/47/files
|
||||
|
||||
# You can access it at https://k2-fsa.github.io/icefall/
|
||||
name: Generate doc
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
- doc
|
||||
pull_request:
|
||||
types: [labeled]
|
||||
|
||||
jobs:
|
||||
build-doc:
|
||||
if: github.event.label.name == 'doc' || github.event_name == 'push'
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [ubuntu-latest]
|
||||
python-version: ["3.8"]
|
||||
steps:
|
||||
# refer to https://github.com/actions/checkout
|
||||
- uses: actions/checkout@v2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Display Python version
|
||||
run: python -c "import sys; print(sys.version)"
|
||||
|
||||
- name: Build doc
|
||||
shell: bash
|
||||
run: |
|
||||
cd docs
|
||||
python3 -m pip install -r ./requirements.txt
|
||||
make html
|
||||
touch build/html/.nojekyll
|
||||
|
||||
- name: Deploy
|
||||
uses: peaceiris/actions-gh-pages@v3
|
||||
with:
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
publish_dir: ./docs/build/html
|
||||
publish_branch: gh-pages
|
@ -35,7 +35,7 @@ on:
|
||||
|
||||
jobs:
|
||||
run_librispeech_pruned_transducer_stateless3_2022_05_13:
|
||||
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
|
||||
if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
|
2
.github/workflows/style_check.yml
vendored
2
.github/workflows/style_check.yml
vendored
@ -29,7 +29,7 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-18.04, macos-10.15]
|
||||
os: [ubuntu-18.04, macos-latest]
|
||||
python-version: [3.7, 3.9]
|
||||
fail-fast: false
|
||||
|
||||
|
21
README.md
21
README.md
@ -10,6 +10,10 @@ using <https://github.com/k2-fsa/k2>.
|
||||
You can use <https://github.com/k2-fsa/sherpa> to deploy models
|
||||
trained with icefall.
|
||||
|
||||
You can try pre-trained models from within your browser without the need
|
||||
to download or install anything by visiting <https://huggingface.co/spaces/k2-fsa/automatic-speech-recognition>
|
||||
See <https://k2-fsa.github.io/icefall/huggingface/spaces.html> for more details.
|
||||
|
||||
## Installation
|
||||
|
||||
Please refer to <https://icefall.readthedocs.io/en/latest/installation/index.html>
|
||||
@ -246,17 +250,25 @@ We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless mod
|
||||
|
||||
### WenetSpeech
|
||||
|
||||
We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][WenetSpeech_pruned_transducer_stateless2].
|
||||
We provide some models for this recipe: [Pruned stateless RNN-T_2: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][WenetSpeech_pruned_transducer_stateless2] and [Pruned stateless RNN-T_5: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][WenetSpeech_pruned_transducer_stateless5].
|
||||
|
||||
#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with L subset)
|
||||
#### Pruned stateless RNN-T_2: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with L subset, offline ASR)
|
||||
|
||||
| | Dev | Test-Net | Test-Meeting |
|
||||
|----------------------|-------|----------|--------------|
|
||||
| greedy search | 7.80 | 8.75 | 13.49 |
|
||||
| fast beam search | 7.94 | 8.74 | 13.80 |
|
||||
| modified beam search| 7.76 | 8.71 | 13.41 |
|
||||
| fast beam search | 7.94 | 8.74 | 13.80 |
|
||||
|
||||
We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [](https://colab.research.google.com/drive/1EV4e1CHa1GZgEF-bZgizqI9RyFFehIiN?usp=sharing)
|
||||
#### Pruned stateless RNN-T_5: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with L subset)
|
||||
**Streaming**:
|
||||
| | Dev | Test-Net | Test-Meeting |
|
||||
|----------------------|-------|----------|--------------|
|
||||
| greedy_search | 8.78 | 10.12 | 16.16 |
|
||||
| modified_beam_search | 8.53| 9.95 | 15.81 |
|
||||
| fast_beam_search| 9.01 | 10.47 | 16.28 |
|
||||
|
||||
We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless2 model: [](https://colab.research.google.com/drive/1EV4e1CHa1GZgEF-bZgizqI9RyFFehIiN?usp=sharing)
|
||||
|
||||
### Alimeeting
|
||||
|
||||
@ -329,6 +341,7 @@ Please see: [ or your custom
|
||||
# ones.
|
||||
extensions = [
|
||||
"sphinx_rtd_theme",
|
||||
"sphinx.ext.todo",
|
||||
"sphinx_rtd_theme",
|
||||
"sphinxcontrib.youtube",
|
||||
]
|
||||
|
||||
# Add any paths that contain templates here, relative to this directory.
|
||||
|
13
docs/source/huggingface/index.rst
Normal file
13
docs/source/huggingface/index.rst
Normal file
@ -0,0 +1,13 @@
|
||||
Huggingface
|
||||
===========
|
||||
|
||||
This section describes how to find pre-trained models.
|
||||
It also demonstrates how to try them from within your browser
|
||||
without installing anything by using
|
||||
`Huggingface spaces <https://huggingface.co/spaces/k2-fsa/automatic-speech-recognition>`_.
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 2
|
||||
|
||||
pretrained-models
|
||||
spaces
|
BIN
docs/source/huggingface/pic/hugging-face-sherpa-2.png
Normal file
BIN
docs/source/huggingface/pic/hugging-face-sherpa-2.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 455 KiB |
BIN
docs/source/huggingface/pic/hugging-face-sherpa-3.png
Normal file
BIN
docs/source/huggingface/pic/hugging-face-sherpa-3.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 392 KiB |
BIN
docs/source/huggingface/pic/hugging-face-sherpa.png
Normal file
BIN
docs/source/huggingface/pic/hugging-face-sherpa.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 426 KiB |
17
docs/source/huggingface/pretrained-models.rst
Normal file
17
docs/source/huggingface/pretrained-models.rst
Normal file
@ -0,0 +1,17 @@
|
||||
Pre-trained models
|
||||
==================
|
||||
|
||||
We have uploaded pre-trained models for all recipes in ``icefall``
|
||||
to `<https://huggingface.co/>`_.
|
||||
|
||||
You can find them by visiting the following link:
|
||||
|
||||
`<https://huggingface.co/models?search=icefall>`_.
|
||||
|
||||
You can also find links of pre-trained models for a specific recipe
|
||||
by looking at the corresponding ``RESULTS.md``. For instance:
|
||||
|
||||
- `<https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS.md>`_
|
||||
- `<https://github.com/k2-fsa/icefall/blob/master/egs/aishell/ASR/RESULTS.md>`_
|
||||
- `<https://github.com/k2-fsa/icefall/blob/master/egs/gigaspeech/ASR/RESULTS.md>`_
|
||||
- `<https://github.com/k2-fsa/icefall/blob/master/egs/wenetspeech/ASR/RESULTS.md>`_
|
65
docs/source/huggingface/spaces.rst
Normal file
65
docs/source/huggingface/spaces.rst
Normal file
@ -0,0 +1,65 @@
|
||||
Huggingface spaces
|
||||
==================
|
||||
|
||||
We have integrated the server framework
|
||||
`sherpa <http://github.com/k2-fsa/sherpa>`_
|
||||
with `Huggingface spaces <https://huggingface.co/spaces/k2-fsa/automatic-speech-recognition>`_
|
||||
so that you can try pre-trained models from within your browser
|
||||
without the need to download or install anything.
|
||||
|
||||
All you need is a browser, which can be run on Windows, macOS, Linux, or even on your
|
||||
iPad and your phone.
|
||||
|
||||
Start your browser and visit the following address:
|
||||
|
||||
`<https://huggingface.co/spaces/k2-fsa/automatic-speech-recognition>`_
|
||||
|
||||
and you will see a page like the following screenshot:
|
||||
|
||||
.. image:: ./pic/hugging-face-sherpa.png
|
||||
:alt: screenshot of `<https://huggingface.co/spaces/k2-fsa/automatic-speech-recognition>`_
|
||||
:target: https://huggingface.co/spaces/k2-fsa/automatic-speech-recognition
|
||||
|
||||
You can:
|
||||
|
||||
1. Select a language for recognition. Currently, we provide pre-trained models
|
||||
from ``icefall`` for the following languages: ``Chinese``, ``English``, and
|
||||
``Chinese+English``.
|
||||
2. After selecting the target language, you can select a pre-trained model
|
||||
corresponding to the language.
|
||||
3. Select the decoding method. Currently, it provides ``greedy search``
|
||||
and ``modified_beam_search``.
|
||||
4. If you selected ``modified_beam_search``, you can choose the number of
|
||||
active paths during the search.
|
||||
5. Either upload a file or record your speech for recognition.
|
||||
6. Click the button ``Submit for recognition``.
|
||||
7. Wait for a moment and you will get the recognition results.
|
||||
|
||||
The following screenshot shows an example when selecting ``Chinese+English``:
|
||||
|
||||
.. image:: ./pic/hugging-face-sherpa-3.png
|
||||
:alt: screenshot of `<https://huggingface.co/spaces/k2-fsa/automatic-speech-recognition>`_
|
||||
:target: https://huggingface.co/spaces/k2-fsa/automatic-speech-recognition
|
||||
|
||||
|
||||
In the bottom part of the page, you can find a table of examples. You can click
|
||||
one of them and then click ``Submit for recognition``.
|
||||
|
||||
.. image:: ./pic/hugging-face-sherpa-2.png
|
||||
:alt: screenshot of `<https://huggingface.co/spaces/k2-fsa/automatic-speech-recognition>`_
|
||||
:target: https://huggingface.co/spaces/k2-fsa/automatic-speech-recognition
|
||||
|
||||
YouTube Video
|
||||
-------------
|
||||
|
||||
We provide the following YouTube video demonstrating how to use
|
||||
`<https://huggingface.co/spaces/k2-fsa/automatic-speech-recognition>`_.
|
||||
|
||||
.. note::
|
||||
|
||||
To get the latest news of `next-gen Kaldi <https://github.com/k2-fsa>`_, please subscribe
|
||||
the following YouTube channel by `Nadira Povey <https://www.youtube.com/channel/UC_VaumpkmINz1pNkFXAN9mw>`_:
|
||||
|
||||
`<https://www.youtube.com/channel/UC_VaumpkmINz1pNkFXAN9mw>`_
|
||||
|
||||
.. youtube:: ElN3r9dkKE4
|
@ -23,3 +23,4 @@ speech recognition recipes using `k2 <https://github.com/k2-fsa/k2>`_.
|
||||
installation/index
|
||||
recipes/index
|
||||
contributing/index
|
||||
huggingface/index
|
||||
|
@ -474,3 +474,19 @@ The decoding log is:
|
||||
**Congratulations!** You have successfully setup the environment and have run the first recipe in ``icefall``.
|
||||
|
||||
Have fun with ``icefall``!
|
||||
|
||||
YouTube Video
|
||||
-------------
|
||||
|
||||
We provide the following YouTube video showing how to install ``icefall``.
|
||||
It also shows how to debug various problems that you may encounter while
|
||||
using ``icefall``.
|
||||
|
||||
.. note::
|
||||
|
||||
To get the latest news of `next-gen Kaldi <https://github.com/k2-fsa>`_, please subscribe
|
||||
the following YouTube channel by `Nadira Povey <https://www.youtube.com/channel/UC_VaumpkmINz1pNkFXAN9mw>`_:
|
||||
|
||||
`<https://www.youtube.com/channel/UC_VaumpkmINz1pNkFXAN9mw>`_
|
||||
|
||||
.. youtube:: LVmrBD0tLfE
|
||||
|
@ -70,6 +70,17 @@ To run stage 2 to stage 5, use:
|
||||
All generated files by ``./prepare.sh``, e.g., features, lexicon, etc,
|
||||
are saved in ``./data`` directory.
|
||||
|
||||
We provide the following YouTube video showing how to run ``./prepare.sh``.
|
||||
|
||||
.. note::
|
||||
|
||||
To get the latest news of `next-gen Kaldi <https://github.com/k2-fsa>`_, please subscribe
|
||||
the following YouTube channel by `Nadira Povey <https://www.youtube.com/channel/UC_VaumpkmINz1pNkFXAN9mw>`_:
|
||||
|
||||
`<https://www.youtube.com/channel/UC_VaumpkmINz1pNkFXAN9mw>`_
|
||||
|
||||
.. youtube:: ofEIoJL-mGM
|
||||
|
||||
|
||||
Training
|
||||
--------
|
||||
|
@ -45,6 +45,16 @@ To run stage 2 to stage 5, use:
|
||||
|
||||
$ ./prepare.sh --stage 2 --stop-stage 5
|
||||
|
||||
We provide the following YouTube video showing how to run ``./prepare.sh``.
|
||||
|
||||
.. note::
|
||||
|
||||
To get the latest news of `next-gen Kaldi <https://github.com/k2-fsa>`_, please subscribe
|
||||
the following YouTube channel by `Nadira Povey <https://www.youtube.com/channel/UC_VaumpkmINz1pNkFXAN9mw>`_:
|
||||
|
||||
`<https://www.youtube.com/channel/UC_VaumpkmINz1pNkFXAN9mw>`_
|
||||
|
||||
.. youtube:: ofEIoJL-mGM
|
||||
|
||||
Training
|
||||
--------
|
||||
|
@ -367,6 +367,7 @@ def decode_dataset(
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
texts = [list(str(text).replace(" ", "")) for text in texts]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
@ -379,8 +380,8 @@ def decode_dataset(
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for hyp_words, ref_text in zip(hyps, texts):
|
||||
this_batch.append((ref_text, hyp_words))
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
this_batch.append((cut_id, ref_text, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
@ -405,6 +406,7 @@ def save_results(
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
@ -528,6 +530,8 @@ def main():
|
||||
from lhotse import CutSet
|
||||
from lhotse.dataset.webdataset import export_to_webdataset
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
aidatatang_200zh = Aidatatang_200zhAsrDataModule(args)
|
||||
|
||||
dev = "dev"
|
||||
|
@ -374,6 +374,7 @@ def decode_dataset(
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
@ -389,9 +390,9 @@ def decode_dataset(
|
||||
for lm_scale, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for hyp_words, ref_text in zip(hyps, texts):
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((ref_words, hyp_words))
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[lm_scale].extend(this_batch)
|
||||
|
||||
@ -419,6 +420,7 @@ def save_results(
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
if enable_log:
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
@ -429,7 +431,9 @@ def save_results(
|
||||
# we compute CER for aishell dataset.
|
||||
results_char = []
|
||||
for res in results:
|
||||
results_char.append((list("".join(res[0])), list("".join(res[1]))))
|
||||
results_char.append(
|
||||
(res[0], list("".join(res[1])), list("".join(res[2])))
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
|
||||
@ -537,6 +541,8 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
aishell = AishellAsrDataModule(args)
|
||||
test_cuts = aishell.test_cuts()
|
||||
test_dl = aishell.test_dataloaders(test_cuts)
|
||||
|
@ -386,6 +386,7 @@ def decode_dataset(
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
@ -401,9 +402,9 @@ def decode_dataset(
|
||||
for lm_scale, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for hyp_words, ref_text in zip(hyps, texts):
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((ref_words, hyp_words))
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[lm_scale].extend(this_batch)
|
||||
|
||||
@ -431,6 +432,7 @@ def save_results(
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
if enable_log:
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
@ -441,7 +443,9 @@ def save_results(
|
||||
# we compute CER for aishell dataset.
|
||||
results_char = []
|
||||
for res in results:
|
||||
results_char.append((list("".join(res[0])), list("".join(res[1]))))
|
||||
results_char.append(
|
||||
(res[0], list("".join(res[1])), list("".join(res[2])))
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
|
||||
@ -556,6 +560,8 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
aishell = AishellAsrDataModule(args)
|
||||
test_cuts = aishell.test_cuts()
|
||||
test_dl = aishell.test_dataloaders(test_cuts)
|
||||
|
@ -48,6 +48,8 @@ if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
|
||||
log "stage -1: Download LM"
|
||||
# We assume that you have installed the git-lfs, if not, you could install it
|
||||
# using: `sudo apt-get install git-lfs && git-lfs install`
|
||||
git lfs 1>/dev/null 2>&1 || (echo "please install git-lfs, consider using: sudo apt-get install git-lfs && git-lfs install" && exit 1)
|
||||
|
||||
if [ ! -f $dl_dir/lm/3-gram.unpruned.arpa ]; then
|
||||
git clone https://huggingface.co/pkufool/aishell_lm $dl_dir/lm
|
||||
fi
|
||||
|
@ -377,6 +377,7 @@ def decode_dataset(
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
@ -389,9 +390,9 @@ def decode_dataset(
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for hyp_words, ref_text in zip(hyps, texts):
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((ref_words, hyp_words))
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
@ -416,6 +417,7 @@ def save_results(
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
@ -427,7 +429,9 @@ def save_results(
|
||||
# we compute CER for aishell dataset.
|
||||
results_char = []
|
||||
for res in results:
|
||||
results_char.append((list("".join(res[0])), list("".join(res[1]))))
|
||||
results_char.append(
|
||||
(res[0], list("".join(res[1])), list("".join(res[2])))
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results_char, enable_log=True
|
||||
@ -464,6 +468,7 @@ def main():
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
params.datatang_prob = 0
|
||||
|
||||
assert params.decoding_method in (
|
||||
"greedy_search",
|
||||
@ -605,6 +610,8 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
asr_datamodule = AsrDataModule(args)
|
||||
aishell = AIShell(manifest_dir=args.manifest_dir)
|
||||
test_cuts = aishell.test_cuts()
|
||||
|
@ -157,6 +157,7 @@ def main():
|
||||
|
||||
params.blank_id = 0
|
||||
params.vocab_size = max(lexicon.tokens) + 1
|
||||
params.datatang_prob = 0
|
||||
|
||||
logging.info(params)
|
||||
|
||||
|
@ -223,6 +223,7 @@ def main():
|
||||
|
||||
params.blank_id = 0
|
||||
params.vocab_size = max(lexicon.tokens) + 1
|
||||
params.datatang_prob = 0
|
||||
|
||||
logging.info(params)
|
||||
|
||||
|
@ -22,8 +22,12 @@
|
||||
Usage:
|
||||
|
||||
./prepare.sh
|
||||
|
||||
# If you use a non-zero value for --datatang-prob, you also need to run
|
||||
./prepare_aidatatang_200zh.sh
|
||||
|
||||
If you use --datatang-prob=0, then you don't need to run the above script.
|
||||
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
|
||||
|
||||
@ -343,9 +347,12 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--datatang-prob",
|
||||
type=float,
|
||||
default=0.2,
|
||||
help="The probability to select a batch from the "
|
||||
"aidatatang_200zh dataset",
|
||||
default=0.0,
|
||||
help="""The probability to select a batch from the
|
||||
aidatatang_200zh dataset.
|
||||
If it is set to 0, you don't need to download the data
|
||||
for aidatatang_200zh.
|
||||
""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
@ -457,8 +464,12 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
|
||||
decoder = get_decoder_model(params)
|
||||
joiner = get_joiner_model(params)
|
||||
|
||||
if params.datatang_prob > 0:
|
||||
decoder_datatang = get_decoder_model(params)
|
||||
joiner_datatang = get_joiner_model(params)
|
||||
else:
|
||||
decoder_datatang = None
|
||||
joiner_datatang = None
|
||||
|
||||
model = Transducer(
|
||||
encoder=encoder,
|
||||
@ -726,7 +737,7 @@ def train_one_epoch(
|
||||
scheduler: LRSchedulerType,
|
||||
graph_compiler: CharCtcTrainingGraphCompiler,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
datatang_train_dl: torch.utils.data.DataLoader,
|
||||
datatang_train_dl: Optional[torch.utils.data.DataLoader],
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
rng: random.Random,
|
||||
scaler: GradScaler,
|
||||
@ -778,13 +789,17 @@ def train_one_epoch(
|
||||
dl_weights = [1 - params.datatang_prob, params.datatang_prob]
|
||||
|
||||
iter_aishell = iter(train_dl)
|
||||
if datatang_train_dl is not None:
|
||||
iter_datatang = iter(datatang_train_dl)
|
||||
|
||||
batch_idx = 0
|
||||
|
||||
while True:
|
||||
if datatang_train_dl is not None:
|
||||
idx = rng.choices((0, 1), weights=dl_weights, k=1)[0]
|
||||
dl = iter_aishell if idx == 0 else iter_datatang
|
||||
else:
|
||||
dl = iter_aishell
|
||||
|
||||
try:
|
||||
batch = next(dl)
|
||||
@ -808,7 +823,11 @@ def train_one_epoch(
|
||||
warmup=(params.batch_idx_train / params.model_warm_step),
|
||||
)
|
||||
# summary stats
|
||||
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
|
||||
if datatang_train_dl is not None:
|
||||
tot_loss = (
|
||||
tot_loss * (1 - 1 / params.reset_interval)
|
||||
) + loss_info
|
||||
|
||||
if aishell:
|
||||
aishell_tot_loss = (
|
||||
aishell_tot_loss * (1 - 1 / params.reset_interval)
|
||||
@ -871,12 +890,21 @@ def train_one_epoch(
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
if datatang_train_dl is not None:
|
||||
datatang_str = f"datatang_tot_loss[{datatang_tot_loss}], "
|
||||
tot_loss_str = (
|
||||
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
|
||||
)
|
||||
else:
|
||||
tot_loss_str = ""
|
||||
datatang_str = ""
|
||||
|
||||
logging.info(
|
||||
f"Epoch {params.cur_epoch}, "
|
||||
f"batch {batch_idx}, {prefix}_loss[{loss_info}], "
|
||||
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
|
||||
f"{tot_loss_str}"
|
||||
f"aishell_tot_loss[{aishell_tot_loss}], "
|
||||
f"datatang_tot_loss[{datatang_tot_loss}], "
|
||||
f"{datatang_str}"
|
||||
f"batch size: {batch_size}, "
|
||||
f"lr: {cur_lr:.2e}"
|
||||
)
|
||||
@ -891,12 +919,15 @@ def train_one_epoch(
|
||||
f"train/current_{prefix}_",
|
||||
params.batch_idx_train,
|
||||
)
|
||||
if datatang_train_dl is not None:
|
||||
# If it is None, tot_loss is the same as aishell_tot_loss.
|
||||
tot_loss.write_summary(
|
||||
tb_writer, "train/tot_", params.batch_idx_train
|
||||
)
|
||||
aishell_tot_loss.write_summary(
|
||||
tb_writer, "train/aishell_tot_", params.batch_idx_train
|
||||
)
|
||||
if datatang_train_dl is not None:
|
||||
datatang_tot_loss.write_summary(
|
||||
tb_writer, "train/datatang_tot_", params.batch_idx_train
|
||||
)
|
||||
@ -917,7 +948,10 @@ def train_one_epoch(
|
||||
tb_writer, "train/valid_", params.batch_idx_train
|
||||
)
|
||||
|
||||
if datatang_train_dl is not None:
|
||||
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
||||
else:
|
||||
loss_value = aishell_tot_loss["loss"] / aishell_tot_loss["frames"]
|
||||
params.train_loss = loss_value
|
||||
if params.train_loss < params.best_train_loss:
|
||||
params.best_train_epoch = params.cur_epoch
|
||||
@ -1004,7 +1038,16 @@ def run(rank, world_size, args):
|
||||
model.to(device)
|
||||
if world_size > 1:
|
||||
logging.info("Using DDP")
|
||||
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
|
||||
if params.datatang_prob > 0:
|
||||
find_unused_parameters = True
|
||||
else:
|
||||
find_unused_parameters = False
|
||||
|
||||
model = DDP(
|
||||
model,
|
||||
device_ids=[rank],
|
||||
find_unused_parameters=find_unused_parameters,
|
||||
)
|
||||
|
||||
optimizer = Eve(model.parameters(), lr=params.initial_lr)
|
||||
|
||||
@ -1032,11 +1075,6 @@ def run(rank, world_size, args):
|
||||
train_cuts = aishell.train_cuts()
|
||||
train_cuts = filter_short_and_long_utterances(train_cuts)
|
||||
|
||||
datatang = AIDatatang200zh(manifest_dir=args.manifest_dir)
|
||||
train_datatang_cuts = datatang.train_cuts()
|
||||
train_datatang_cuts = filter_short_and_long_utterances(train_datatang_cuts)
|
||||
train_datatang_cuts = train_datatang_cuts.repeat(times=None)
|
||||
|
||||
if args.enable_musan:
|
||||
cuts_musan = load_manifest(
|
||||
Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
|
||||
@ -1052,11 +1090,21 @@ def run(rank, world_size, args):
|
||||
cuts_musan=cuts_musan,
|
||||
)
|
||||
|
||||
if params.datatang_prob > 0:
|
||||
datatang = AIDatatang200zh(manifest_dir=args.manifest_dir)
|
||||
train_datatang_cuts = datatang.train_cuts()
|
||||
train_datatang_cuts = filter_short_and_long_utterances(
|
||||
train_datatang_cuts
|
||||
)
|
||||
train_datatang_cuts = train_datatang_cuts.repeat(times=None)
|
||||
datatang_train_dl = asr_datamodule.train_dataloaders(
|
||||
train_datatang_cuts,
|
||||
on_the_fly_feats=False,
|
||||
cuts_musan=cuts_musan,
|
||||
)
|
||||
else:
|
||||
datatang_train_dl = None
|
||||
logging.info("Not using aidatatang_200zh for training")
|
||||
|
||||
valid_cuts = aishell.valid_cuts()
|
||||
valid_dl = asr_datamodule.valid_dataloaders(valid_cuts)
|
||||
@ -1065,6 +1113,7 @@ def run(rank, world_size, args):
|
||||
train_dl,
|
||||
# datatang_train_dl
|
||||
]:
|
||||
if dl is not None:
|
||||
scan_pessimistic_batches_for_oom(
|
||||
model=model,
|
||||
train_dl=dl,
|
||||
@ -1083,6 +1132,7 @@ def run(rank, world_size, args):
|
||||
scheduler.step_epoch(epoch - 1)
|
||||
fix_random_seed(params.seed + epoch - 1)
|
||||
train_dl.sampler.set_epoch(epoch - 1)
|
||||
if datatang_train_dl is not None:
|
||||
datatang_train_dl.sampler.set_epoch(epoch)
|
||||
|
||||
if tb_writer is not None:
|
||||
|
@ -241,6 +241,7 @@ def decode_dataset(
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
@ -253,9 +254,9 @@ def decode_dataset(
|
||||
for lm_scale, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for hyp_words, ref_text in zip(hyps, texts):
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((ref_words, hyp_words))
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[lm_scale].extend(this_batch)
|
||||
|
||||
@ -278,6 +279,7 @@ def save_results(
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
@ -287,7 +289,9 @@ def save_results(
|
||||
# We compute CER for aishell dataset.
|
||||
results_char = []
|
||||
for res in results:
|
||||
results_char.append((list("".join(res[0])), list("".join(res[1]))))
|
||||
results_char.append(
|
||||
(res[0], list("".join(res[1])), list("".join(res[2])))
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(f, f"{test_set_name}-{key}", results_char)
|
||||
test_set_wers[key] = wer
|
||||
@ -365,6 +369,8 @@ def main():
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
aishell = AishellAsrDataModule(args)
|
||||
test_cuts = aishell.test_cuts()
|
||||
test_dl = aishell.test_dataloaders(test_cuts)
|
||||
|
@ -38,8 +38,8 @@ from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
write_error_stats,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
|
||||
@ -296,6 +296,7 @@ def decode_dataset(
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
@ -307,9 +308,9 @@ def decode_dataset(
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for hyp_words, ref_text in zip(hyps, texts):
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((ref_words, hyp_words))
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
@ -334,6 +335,7 @@ def save_results(
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
@ -344,7 +346,9 @@ def save_results(
|
||||
# we compute CER for aishell dataset.
|
||||
results_char = []
|
||||
for res in results:
|
||||
results_char.append((list("".join(res[0])), list("".join(res[1]))))
|
||||
results_char.append(
|
||||
(res[0], list("".join(res[1])), list("".join(res[2])))
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results_char, enable_log=True
|
||||
@ -438,6 +442,8 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
aishell = AishellAsrDataModule(args)
|
||||
test_cuts = aishell.test_cuts()
|
||||
test_dl = aishell.test_dataloaders(test_cuts)
|
||||
|
@ -341,6 +341,7 @@ def decode_dataset(
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
@ -353,9 +354,9 @@ def decode_dataset(
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for hyp_words, ref_text in zip(hyps, texts):
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((ref_words, hyp_words))
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
@ -380,6 +381,7 @@ def save_results(
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
@ -391,7 +393,9 @@ def save_results(
|
||||
# we compute CER for aishell dataset.
|
||||
results_char = []
|
||||
for res in results:
|
||||
results_char.append((list("".join(res[0])), list("".join(res[1]))))
|
||||
results_char.append(
|
||||
(res[0], list("".join(res[1])), list("".join(res[2])))
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results_char, enable_log=True
|
||||
@ -496,6 +500,8 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
asr_datamodule = AsrDataModule(args)
|
||||
aishell = AIShell(manifest_dir=args.manifest_dir)
|
||||
test_cuts = aishell.test_cuts()
|
||||
|
@ -345,6 +345,7 @@ def decode_dataset(
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
@ -357,9 +358,9 @@ def decode_dataset(
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for hyp_words, ref_text in zip(hyps, texts):
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((ref_words, hyp_words))
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
@ -384,6 +385,7 @@ def save_results(
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
@ -395,7 +397,9 @@ def save_results(
|
||||
# we compute CER for aishell dataset.
|
||||
results_char = []
|
||||
for res in results:
|
||||
results_char.append((list("".join(res[0])), list("".join(res[1]))))
|
||||
results_char.append(
|
||||
(res[0], list("".join(res[1])), list("".join(res[2])))
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results_char, enable_log=True
|
||||
@ -498,6 +502,8 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
aishell = AishellAsrDataModule(args)
|
||||
test_cuts = aishell.test_cuts()
|
||||
test_dl = aishell.test_dataloaders(test_cuts)
|
||||
|
19
egs/aishell2/ASR/README.md
Normal file
19
egs/aishell2/ASR/README.md
Normal file
@ -0,0 +1,19 @@
|
||||
|
||||
# Introduction
|
||||
|
||||
This recipe includes some different ASR models trained with Aishell2.
|
||||
|
||||
[./RESULTS.md](./RESULTS.md) contains the latest results.
|
||||
|
||||
# Transducers
|
||||
|
||||
There are various folders containing the name `transducer` in this folder.
|
||||
The following table lists the differences among them.
|
||||
|
||||
| | Encoder | Decoder | Comment |
|
||||
|---------------------------------------|---------------------|--------------------|-----------------------------|
|
||||
| `pruned_transducer_stateless5` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless5 in librispeech recipe |
|
||||
|
||||
The decoder in `transducer_stateless` is modified from the paper
|
||||
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
|
||||
We place an additional Conv1d layer right after the input embedding layer.
|
89
egs/aishell2/ASR/RESULTS.md
Normal file
89
egs/aishell2/ASR/RESULTS.md
Normal file
@ -0,0 +1,89 @@
|
||||
## Results
|
||||
|
||||
### Aishell2 char-based training results (Pruned Transducer 5)
|
||||
|
||||
#### 2022-07-11
|
||||
|
||||
Using the codes from this commit https://github.com/k2-fsa/icefall/pull/465.
|
||||
|
||||
When training with context size equals to 1, the WERs are
|
||||
|
||||
| | dev-ios | test-ios | comment |
|
||||
|------------------------------------|-------|----------|----------------------------------|
|
||||
| greedy search | 5.57 | 5.89 | --epoch 25, --avg 5, --max-duration 600 |
|
||||
| modified beam search (beam size 4) | 5.32 | 5.56 | --epoch 25, --avg 5, --max-duration 600 |
|
||||
| fast beam search (set as default) | 5.5 | 5.78 | --epoch 25, --avg 5, --max-duration 600 |
|
||||
| fast beam search nbest | 5.46 | 5.74 | --epoch 25, --avg 5, --max-duration 600 |
|
||||
| fast beam search oracle | 1.92 | 2.2 | --epoch 25, --avg 5, --max-duration 600 |
|
||||
| fast beam search nbest LG | 5.59 | 5.93 | --epoch 25, --avg 5, --max-duration 600 |
|
||||
|
||||
The training command for reproducing is given below:
|
||||
|
||||
```bash
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
|
||||
./pruned_transducer_stateless5/train.py \
|
||||
--world-size 4 \
|
||||
--lang-dir data/lang_char \
|
||||
--num-epochs 40 \
|
||||
--start-epoch 1 \
|
||||
--exp-dir /result \
|
||||
--max-duration 300 \
|
||||
--use-fp16 0 \
|
||||
--num-encoder-layers 24 \
|
||||
--dim-feedforward 1536 \
|
||||
--nhead 8 \
|
||||
--encoder-dim 384 \
|
||||
--decoder-dim 512 \
|
||||
--joiner-dim 512 \
|
||||
--context-size 1
|
||||
```
|
||||
|
||||
The decoding command is:
|
||||
```bash
|
||||
for method in greedy_search modified_beam_search \
|
||||
fast_beam_search fast_beam_search_nbest \
|
||||
fast_beam_search_nbest_oracle fast_beam_search_nbest_LG; do
|
||||
./pruned_transducer_stateless5/decode.py \
|
||||
--epoch 25 \
|
||||
--avg 5 \
|
||||
--exp-dir ./pruned_transducer_stateless5/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method $method \
|
||||
--max-sym-per-frame 1 \
|
||||
--num-encoder-layers 24 \
|
||||
--dim-feedforward 1536 \
|
||||
--nhead 8 \
|
||||
--encoder-dim 384 \
|
||||
--decoder-dim 512 \
|
||||
--joiner-dim 512 \
|
||||
--context-size 1 \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64 \
|
||||
--num-paths 200 \
|
||||
--nbest-scale 0.5 \
|
||||
--context-size 1 \
|
||||
--use-averaged-model True
|
||||
done
|
||||
```
|
||||
The tensorboard training log can be found at
|
||||
https://tensorboard.dev/experiment/RXyX4QjQQVKjBS2eQ2Qajg/#scalars
|
||||
|
||||
A pre-trained model and decoding logs can be found at <https://huggingface.co/yuekai/icefall-asr-aishell2-pruned-transducer-stateless5-B-2022-07-12>
|
||||
|
||||
When training with context size equals to 2, the WERs are
|
||||
|
||||
| | dev-ios | test-ios | comment |
|
||||
|------------------------------------|-------|----------|----------------------------------|
|
||||
| greedy search | 5.47 | 5.81 | --epoch 25, --avg 5, --max-duration 600 |
|
||||
| modified beam search (beam size 4) | 5.38 | 5.61 | --epoch 25, --avg 5, --max-duration 600 |
|
||||
| fast beam search (set as default) | 5.36 | 5.61 | --epoch 25, --avg 5, --max-duration 600 |
|
||||
| fast beam search nbest | 5.37 | 5.6 | --epoch 25, --avg 5, --max-duration 600 |
|
||||
| fast beam search oracle | 2.04 | 2.2 | --epoch 25, --avg 5, --max-duration 600 |
|
||||
| fast beam search nbest LG | 5.59 | 5.82 | --epoch 25, --avg 5, --max-duration 600 |
|
||||
|
||||
The tensorboard training log can be found at
|
||||
https://tensorboard.dev/experiment/5AxJ8LHoSre8kDAuLp4L7Q/#scalars
|
||||
|
||||
A pre-trained model and decoding logs can be found at <https://huggingface.co/yuekai/icefall-asr-aishell2-pruned-transducer-stateless5-A-2022-07-12>
|
0
egs/aishell2/ASR/local/__init__.py
Executable file
0
egs/aishell2/ASR/local/__init__.py
Executable file
1
egs/aishell2/ASR/local/compile_lg.py
Symbolic link
1
egs/aishell2/ASR/local/compile_lg.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/local/compile_lg.py
|
114
egs/aishell2/ASR/local/compute_fbank_aishell2.py
Executable file
114
egs/aishell2/ASR/local/compute_fbank_aishell2.py
Executable file
@ -0,0 +1,114 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This file computes fbank features of the aishell2 dataset.
|
||||
It looks for manifests in the directory data/manifests.
|
||||
|
||||
The generated fbank features are saved in data/fbank.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import get_executor
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slow things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def compute_fbank_aishell2(num_mel_bins: int = 80):
|
||||
src_dir = Path("data/manifests")
|
||||
output_dir = Path("data/fbank")
|
||||
num_jobs = min(15, os.cpu_count())
|
||||
|
||||
dataset_parts = (
|
||||
"train",
|
||||
"dev",
|
||||
"test",
|
||||
)
|
||||
prefix = "aishell2"
|
||||
suffix = "jsonl.gz"
|
||||
manifests = read_manifests_if_cached(
|
||||
dataset_parts=dataset_parts,
|
||||
output_dir=src_dir,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
)
|
||||
assert manifests is not None
|
||||
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||
|
||||
with get_executor() as ex: # Initialize the executor only once.
|
||||
for partition, m in manifests.items():
|
||||
if (output_dir / f"{prefix}_cuts_{partition}.{suffix}").is_file():
|
||||
logging.info(f"{partition} already exists - skipping.")
|
||||
continue
|
||||
logging.info(f"Processing {partition}")
|
||||
cut_set = CutSet.from_manifests(
|
||||
recordings=m["recordings"],
|
||||
supervisions=m["supervisions"],
|
||||
)
|
||||
if "train" in partition:
|
||||
cut_set = (
|
||||
cut_set
|
||||
+ cut_set.perturb_speed(0.9)
|
||||
+ cut_set.perturb_speed(1.1)
|
||||
)
|
||||
cut_set = cut_set.compute_and_store_features(
|
||||
extractor=extractor,
|
||||
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
|
||||
# when an executor is specified, make more partitions
|
||||
num_jobs=num_jobs if ex is None else 80,
|
||||
executor=ex,
|
||||
storage_type=LilcomChunkyWriter,
|
||||
)
|
||||
cut_set.to_file(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--num-mel-bins",
|
||||
type=int,
|
||||
default=80,
|
||||
help="""The number of mel bins for Fbank""",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
args = get_args()
|
||||
compute_fbank_aishell2(num_mel_bins=args.num_mel_bins)
|
1
egs/aishell2/ASR/local/compute_fbank_musan.py
Symbolic link
1
egs/aishell2/ASR/local/compute_fbank_musan.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/local/compute_fbank_musan.py
|
96
egs/aishell2/ASR/local/display_manifest_statistics.py
Executable file
96
egs/aishell2/ASR/local/display_manifest_statistics.py
Executable file
@ -0,0 +1,96 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This file displays duration statistics of utterances in a manifest.
|
||||
You can use the displayed value to choose minimum/maximum duration
|
||||
to remove short and long utterances during the training.
|
||||
|
||||
See the function `remove_short_and_long_utt()` in transducer_stateless/train.py
|
||||
for usage.
|
||||
"""
|
||||
|
||||
|
||||
from lhotse import load_manifest_lazy
|
||||
|
||||
|
||||
def main():
|
||||
paths = [
|
||||
"./data/fbank/aishell2_cuts_train.jsonl.gz",
|
||||
"./data/fbank/aishell2_cuts_dev.jsonl.gz",
|
||||
"./data/fbank/aishell2_cuts_test.jsonl.gz",
|
||||
]
|
||||
|
||||
for path in paths:
|
||||
print(f"Starting display the statistics for {path}")
|
||||
cuts = load_manifest_lazy(path)
|
||||
cuts.describe()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
"""
|
||||
Starting display the statistics for ./data/fbank/aishell2_cuts_train.jsonl.gz
|
||||
Cuts count: 3026106
|
||||
Total duration (hours): 3021.2
|
||||
Speech duration (hours): 3021.2 (100.0%)
|
||||
***
|
||||
Duration statistics (seconds):
|
||||
mean 3.6
|
||||
std 1.5
|
||||
min 0.3
|
||||
25% 2.4
|
||||
50% 3.3
|
||||
75% 4.4
|
||||
99% 8.2
|
||||
99.5% 8.9
|
||||
99.9% 10.6
|
||||
max 21.5
|
||||
Starting display the statistics for ./data/fbank/aishell2_cuts_dev.jsonl.gz
|
||||
Cuts count: 2500
|
||||
Total duration (hours): 2.0
|
||||
Speech duration (hours): 2.0 (100.0%)
|
||||
***
|
||||
Duration statistics (seconds):
|
||||
mean 2.9
|
||||
std 1.0
|
||||
min 1.1
|
||||
25% 2.2
|
||||
50% 2.7
|
||||
75% 3.4
|
||||
99% 6.3
|
||||
99.5% 6.7
|
||||
99.9% 7.8
|
||||
max 9.4
|
||||
Starting display the statistics for ./data/fbank/aishell2_cuts_test.jsonl.gz
|
||||
Cuts count: 5000
|
||||
Total duration (hours): 4.0
|
||||
Speech duration (hours): 4.0 (100.0%)
|
||||
***
|
||||
Duration statistics (seconds):
|
||||
mean 2.9
|
||||
std 1.0
|
||||
min 1.1
|
||||
25% 2.2
|
||||
50% 2.7
|
||||
75% 3.3
|
||||
99% 6.2
|
||||
99.5% 6.6
|
||||
99.9% 7.7
|
||||
max 8.5
|
||||
"""
|
1
egs/aishell2/ASR/local/prepare_char.py
Symbolic link
1
egs/aishell2/ASR/local/prepare_char.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../aidatatang_200zh/ASR/local/prepare_char.py
|
1
egs/aishell2/ASR/local/prepare_lang.py
Symbolic link
1
egs/aishell2/ASR/local/prepare_lang.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../wenetspeech/ASR/local/prepare_lang.py
|
1
egs/aishell2/ASR/local/prepare_words.py
Symbolic link
1
egs/aishell2/ASR/local/prepare_words.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../wenetspeech/ASR/local/prepare_words.py
|
1
egs/aishell2/ASR/local/text2segments.py
Symbolic link
1
egs/aishell2/ASR/local/text2segments.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../wenetspeech/ASR/local/text2segments.py
|
1
egs/aishell2/ASR/local/text2token.py
Symbolic link
1
egs/aishell2/ASR/local/text2token.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../aidatatang_200zh/ASR/local/text2token.py
|
181
egs/aishell2/ASR/prepare.sh
Executable file
181
egs/aishell2/ASR/prepare.sh
Executable file
@ -0,0 +1,181 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -eou pipefail
|
||||
|
||||
nj=30
|
||||
stage=0
|
||||
stop_stage=5
|
||||
|
||||
# We assume dl_dir (download dir) contains the following
|
||||
# directories and files. If not, you need to apply aishell2 through
|
||||
# their official website.
|
||||
# https://www.aishelltech.com/aishell_2
|
||||
#
|
||||
# - $dl_dir/aishell2
|
||||
#
|
||||
#
|
||||
# - $dl_dir/musan
|
||||
# This directory contains the following directories downloaded from
|
||||
# http://www.openslr.org/17/
|
||||
#
|
||||
# - music
|
||||
# - noise
|
||||
# - speech
|
||||
|
||||
dl_dir=$PWD/download
|
||||
|
||||
. shared/parse_options.sh || exit 1
|
||||
|
||||
# All files generated by this script are saved in "data".
|
||||
# You can safely remove "data" and rerun this script to regenerate it.
|
||||
mkdir -p data
|
||||
|
||||
log() {
|
||||
# This function is from espnet
|
||||
local fname=${BASH_SOURCE[1]##*/}
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
log "dl_dir: $dl_dir"
|
||||
|
||||
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
||||
log "stage 0: Download data"
|
||||
|
||||
# If you have pre-downloaded it to /path/to/aishell2,
|
||||
# you can create a symlink
|
||||
#
|
||||
# ln -sfv /path/to/aishell2 $dl_dir/aishell2
|
||||
#
|
||||
# The directory structure is
|
||||
# aishell2/
|
||||
# |-- AISHELL-2
|
||||
# | |-- iOS
|
||||
# |-- data
|
||||
# |-- wav
|
||||
# |-- trans.txt
|
||||
# |-- dev
|
||||
# |-- wav
|
||||
# |-- trans.txt
|
||||
# |-- test
|
||||
# |-- wav
|
||||
# |-- trans.txt
|
||||
|
||||
|
||||
# If you have pre-downloaded it to /path/to/musan,
|
||||
# you can create a symlink
|
||||
#
|
||||
# ln -sfv /path/to/musan $dl_dir/musan
|
||||
#
|
||||
if [ ! -d $dl_dir/musan ]; then
|
||||
lhotse download musan $dl_dir
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||
log "Stage 1: Prepare aishell2 manifest"
|
||||
# We assume that you have downloaded and unzip the aishell2 corpus
|
||||
# to $dl_dir/aishell2
|
||||
if [ ! -f data/manifests/.aishell2_manifests.done ]; then
|
||||
mkdir -p data/manifests
|
||||
lhotse prepare aishell2 $dl_dir/aishell2 data/manifests -j $nj
|
||||
touch data/manifests/.aishell2_manifests.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
log "Stage 2: Prepare musan manifest"
|
||||
# We assume that you have downloaded the musan corpus
|
||||
# to data/musan
|
||||
if [ ! -f data/manifests/.musan_manifests.done ]; then
|
||||
log "It may take 6 minutes"
|
||||
mkdir -p data/manifests
|
||||
lhotse prepare musan $dl_dir/musan data/manifests
|
||||
touch data/manifests/.musan_manifests.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||
log "Stage 3: Compute fbank for aishell2"
|
||||
if [ ! -f data/fbank/.aishell2.done ]; then
|
||||
mkdir -p data/fbank
|
||||
./local/compute_fbank_aishell2.py
|
||||
touch data/fbank/.aishell2.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
log "Stage 4: Compute fbank for musan"
|
||||
if [ ! -f data/fbank/.msuan.done ]; then
|
||||
mkdir -p data/fbank
|
||||
./local/compute_fbank_musan.py
|
||||
touch data/fbank/.msuan.done
|
||||
fi
|
||||
fi
|
||||
|
||||
lang_char_dir=data/lang_char
|
||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
log "Stage 5: Prepare char based lang"
|
||||
mkdir -p $lang_char_dir
|
||||
|
||||
# Prepare text.
|
||||
# Note: in Linux, you can install jq with the following command:
|
||||
# 1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64
|
||||
# 2. chmod +x ./jq
|
||||
# 3. cp jq /usr/bin
|
||||
if [ ! -f $lang_char_dir/text ]; then
|
||||
gunzip -c data/manifests/aishell2_supervisions_train.jsonl.gz \
|
||||
| jq '.text' | sed 's/"//g' \
|
||||
| ./local/text2token.py -t "char" > $lang_char_dir/text
|
||||
fi
|
||||
|
||||
# The implementation of chinese word segmentation for text,
|
||||
# and it will take about 15 minutes.
|
||||
# If you can't install paddle-tiny with python 3.8, please refer to
|
||||
# https://github.com/fxsjy/jieba/issues/920
|
||||
if [ ! -f $lang_char_dir/text_words_segmentation ]; then
|
||||
python3 ./local/text2segments.py \
|
||||
--input-file $lang_char_dir/text \
|
||||
--output-file $lang_char_dir/text_words_segmentation
|
||||
fi
|
||||
|
||||
cat $lang_char_dir/text_words_segmentation | sed 's/ /\n/g' \
|
||||
| sort -u | sed '/^$/d' | uniq > $lang_char_dir/words_no_ids.txt
|
||||
|
||||
if [ ! -f $lang_char_dir/words.txt ]; then
|
||||
python3 ./local/prepare_words.py \
|
||||
--input-file $lang_char_dir/words_no_ids.txt \
|
||||
--output-file $lang_char_dir/words.txt
|
||||
fi
|
||||
|
||||
if [ ! -f $lang_char_dir/L_disambig.pt ]; then
|
||||
python3 ./local/prepare_char.py
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
log "Stage 6: Prepare G"
|
||||
# We assume you have install kaldilm, if not, please install
|
||||
# it using: pip install kaldilm
|
||||
|
||||
if [ ! -f ${lang_char_dir}/3-gram.unpruned.arpa ]; then
|
||||
./shared/make_kn_lm.py \
|
||||
-ngram-order 3 \
|
||||
-text $lang_char_dir/text_words_segmentation \
|
||||
-lm $lang_char_dir/3-gram.unpruned.arpa
|
||||
fi
|
||||
|
||||
mkdir -p data/lm
|
||||
if [ ! -f data/lm/G_3_gram.fst.txt ]; then
|
||||
# It is used in building LG
|
||||
python3 -m kaldilm \
|
||||
--read-symbol-table="$lang_char_dir/words.txt" \
|
||||
--disambig-symbol='#0' \
|
||||
--max-order=3 \
|
||||
$lang_char_dir/3-gram.unpruned.arpa > data/lm/G_3_gram.fst.txt
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
|
||||
log "Stage 7: Compile LG"
|
||||
./local/compile_lg.py --lang-dir $lang_char_dir
|
||||
fi
|
0
egs/aishell2/ASR/pruned_transducer_stateless5/__init__.py
Executable file
0
egs/aishell2/ASR/pruned_transducer_stateless5/__init__.py
Executable file
418
egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py
Executable file
418
egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py
Executable file
@ -0,0 +1,418 @@
|
||||
# Copyright 2021 Piotr Żelasko
|
||||
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import inspect
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
|
||||
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||
CutConcatenate,
|
||||
CutMix,
|
||||
DynamicBucketingSampler,
|
||||
K2SpeechRecognitionDataset,
|
||||
PrecomputedFeatures,
|
||||
SingleCutSampler,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
||||
AudioSamples,
|
||||
OnTheFlyFeatures,
|
||||
)
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
class _SeedWorkers:
|
||||
def __init__(self, seed: int):
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, worker_id: int):
|
||||
fix_random_seed(self.seed + worker_id)
|
||||
|
||||
|
||||
class AiShell2AsrDataModule:
|
||||
"""
|
||||
DataModule for k2 ASR experiments.
|
||||
It assumes there is always one train and valid dataloader,
|
||||
but there can be multiple test dataloaders (e.g. ios, android, mic).
|
||||
|
||||
It contains all the common data pipeline modules used in ASR
|
||||
experiments, e.g.:
|
||||
- dynamic batch size,
|
||||
- bucketing samplers,
|
||||
- cut concatenation,
|
||||
- augmentation,
|
||||
- on-the-fly feature extraction
|
||||
|
||||
This class should be derived for specific corpora used in ASR tasks.
|
||||
"""
|
||||
|
||||
def __init__(self, args: argparse.Namespace):
|
||||
self.args = args
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group(
|
||||
title="ASR data related options",
|
||||
description="These options are used for the preparation of "
|
||||
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
||||
"effective batch sizes, sampling strategies, applied data "
|
||||
"augmentations, etc.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--manifest-dir",
|
||||
type=Path,
|
||||
default=Path("data/fbank"),
|
||||
help="Path to directory with train/valid/test cuts.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-duration",
|
||||
type=int,
|
||||
default=200.0,
|
||||
help="Maximum pooled recordings duration (seconds) in a "
|
||||
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--bucketing-sampler",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, the batches will come from buckets of "
|
||||
"similar duration (saves padding frames).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-buckets",
|
||||
type=int,
|
||||
default=30,
|
||||
help="The number of buckets for the DynamicBucketingSampler"
|
||||
"(you might want to increase it for larger datasets).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--concatenate-cuts",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, utterances (cuts) will be concatenated "
|
||||
"to minimize the amount of padding.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--duration-factor",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Determines the maximum duration of a concatenated cut "
|
||||
"relative to the duration of the longest cut in a batch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--gap",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="The amount of padding (in seconds) inserted between "
|
||||
"concatenated cuts. This padding is filled with noise when "
|
||||
"noise augmentation is used.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--on-the-fly-feats",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, use on-the-fly cut mixing and feature "
|
||||
"extraction. Will drop existing precomputed feature manifests "
|
||||
"if available.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--shuffle",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled (=default), the examples will be "
|
||||
"shuffled for each epoch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--drop-last",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to drop last batch. Used by sampler.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--return-cuts",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, each batch will have the "
|
||||
"field: batch['supervisions']['cut'] with the cuts that "
|
||||
"were used to construct it.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The number of training dataloader workers that "
|
||||
"collect the batches.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--enable-spec-aug",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, use SpecAugment for training dataset.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--spec-aug-time-warp-factor",
|
||||
type=int,
|
||||
default=80,
|
||||
help="Used only when --enable-spec-aug is True. "
|
||||
"It specifies the factor for time warping in SpecAugment. "
|
||||
"Larger values mean more warping. "
|
||||
"A value less than 1 means to disable time warp.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--enable-musan",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, select noise from MUSAN and mix it"
|
||||
"with training dataset. ",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--input-strategy",
|
||||
type=str,
|
||||
default="PrecomputedFeatures",
|
||||
help="AudioSamples or PrecomputedFeatures",
|
||||
)
|
||||
|
||||
def train_dataloaders(
|
||||
self,
|
||||
cuts_train: CutSet,
|
||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Args:
|
||||
cuts_train:
|
||||
CutSet for training.
|
||||
sampler_state_dict:
|
||||
The state dict for the training sampler.
|
||||
"""
|
||||
transforms = []
|
||||
if self.args.enable_musan:
|
||||
logging.info("Enable MUSAN")
|
||||
logging.info("About to get Musan cuts")
|
||||
cuts_musan = load_manifest(
|
||||
self.args.manifest_dir / "musan_cuts.jsonl.gz"
|
||||
)
|
||||
transforms.append(
|
||||
CutMix(
|
||||
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
|
||||
)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable MUSAN")
|
||||
|
||||
if self.args.concatenate_cuts:
|
||||
logging.info(
|
||||
f"Using cut concatenation with duration factor "
|
||||
f"{self.args.duration_factor} and gap {self.args.gap}."
|
||||
)
|
||||
# Cut concatenation should be the first transform in the list,
|
||||
# so that if we e.g. mix noise in, it will fill the gaps between
|
||||
# different utterances.
|
||||
transforms = [
|
||||
CutConcatenate(
|
||||
duration_factor=self.args.duration_factor, gap=self.args.gap
|
||||
)
|
||||
] + transforms
|
||||
|
||||
input_transforms = []
|
||||
if self.args.enable_spec_aug:
|
||||
logging.info("Enable SpecAugment")
|
||||
logging.info(
|
||||
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
|
||||
)
|
||||
# Set the value of num_frame_masks according to Lhotse's version.
|
||||
# In different Lhotse's versions, the default of num_frame_masks is
|
||||
# different.
|
||||
num_frame_masks = 10
|
||||
num_frame_masks_parameter = inspect.signature(
|
||||
SpecAugment.__init__
|
||||
).parameters["num_frame_masks"]
|
||||
if num_frame_masks_parameter.default == 1:
|
||||
num_frame_masks = 2
|
||||
logging.info(f"Num frame mask: {num_frame_masks}")
|
||||
input_transforms.append(
|
||||
SpecAugment(
|
||||
time_warp_factor=self.args.spec_aug_time_warp_factor,
|
||||
num_frame_masks=num_frame_masks,
|
||||
features_mask_size=27,
|
||||
num_feature_masks=2,
|
||||
frames_mask_size=100,
|
||||
)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable SpecAugment")
|
||||
|
||||
logging.info("About to create train dataset")
|
||||
train = K2SpeechRecognitionDataset(
|
||||
input_strategy=eval(self.args.input_strategy)(),
|
||||
cut_transforms=transforms,
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
if self.args.on_the_fly_feats:
|
||||
# NOTE: the PerturbSpeed transform should be added only if we
|
||||
# remove it from data prep stage.
|
||||
# Add on-the-fly speed perturbation; since originally it would
|
||||
# have increased epoch size by 3, we will apply prob 2/3 and use
|
||||
# 3x more epochs.
|
||||
# Speed perturbation probably should come first before
|
||||
# concatenation, but in principle the transforms order doesn't have
|
||||
# to be strict (e.g. could be randomized)
|
||||
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
|
||||
# Drop feats to be on the safe side.
|
||||
train = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
Fbank(FbankConfig(num_mel_bins=80))
|
||||
),
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
if self.args.bucketing_sampler:
|
||||
logging.info("Using DynamicBucketingSampler.")
|
||||
train_sampler = DynamicBucketingSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
num_buckets=self.args.num_buckets,
|
||||
drop_last=self.args.drop_last,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SingleCutSampler.")
|
||||
train_sampler = SingleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
)
|
||||
logging.info("About to create train dataloader")
|
||||
|
||||
if sampler_state_dict is not None:
|
||||
logging.info("Loading sampler state dict")
|
||||
train_sampler.load_state_dict(sampler_state_dict)
|
||||
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
worker_init_fn = _SeedWorkers(seed)
|
||||
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
sampler=train_sampler,
|
||||
batch_size=None,
|
||||
num_workers=self.args.num_workers,
|
||||
persistent_workers=False,
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
|
||||
return train_dl
|
||||
|
||||
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
||||
transforms = []
|
||||
if self.args.concatenate_cuts:
|
||||
transforms = [
|
||||
CutConcatenate(
|
||||
duration_factor=self.args.duration_factor, gap=self.args.gap
|
||||
)
|
||||
] + transforms
|
||||
|
||||
logging.info("About to create dev dataset")
|
||||
if self.args.on_the_fly_feats:
|
||||
validate = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
Fbank(FbankConfig(num_mel_bins=80))
|
||||
),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
else:
|
||||
validate = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
valid_sampler = DynamicBucketingSampler(
|
||||
cuts_valid,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.info("About to create dev dataloader")
|
||||
valid_dl = DataLoader(
|
||||
validate,
|
||||
sampler=valid_sampler,
|
||||
batch_size=None,
|
||||
num_workers=2,
|
||||
persistent_workers=False,
|
||||
)
|
||||
|
||||
return valid_dl
|
||||
|
||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||
logging.debug("About to create test dataset")
|
||||
test = K2SpeechRecognitionDataset(
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
if self.args.on_the_fly_feats
|
||||
else eval(self.args.input_strategy)(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
sampler = DynamicBucketingSampler(
|
||||
cuts,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.debug("About to create test dataloader")
|
||||
test_dl = DataLoader(
|
||||
test,
|
||||
batch_size=None,
|
||||
sampler=sampler,
|
||||
num_workers=self.args.num_workers,
|
||||
)
|
||||
return test_dl
|
||||
|
||||
@lru_cache()
|
||||
def train_cuts(self) -> CutSet:
|
||||
logging.info("About to gen cuts from aishell2_cuts_train.jsonl.gz")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "aishell2_cuts_train.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def valid_cuts(self) -> CutSet:
|
||||
logging.info("About to gen cuts from aishell2_cuts_dev.jsonl.gz")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "aishell2_cuts_dev.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def test_cuts(self) -> CutSet:
|
||||
logging.info("About to gen cuts from aishell2_cuts_test.jsonl.gz")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "aishell2_cuts_test.jsonl.gz"
|
||||
)
|
1
egs/aishell2/ASR/pruned_transducer_stateless5/beam_search.py
Symbolic link
1
egs/aishell2/ASR/pruned_transducer_stateless5/beam_search.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py
|
1
egs/aishell2/ASR/pruned_transducer_stateless5/conformer.py
Symbolic link
1
egs/aishell2/ASR/pruned_transducer_stateless5/conformer.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless5/conformer.py
|
795
egs/aishell2/ASR/pruned_transducer_stateless5/decode.py
Executable file
795
egs/aishell2/ASR/pruned_transducer_stateless5/decode.py
Executable file
@ -0,0 +1,795 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||
# Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
(1) greedy search
|
||||
./pruned_transducer_stateless5/decode.py \
|
||||
--epoch 25 \
|
||||
--avg 5 \
|
||||
--exp-dir ./pruned_transducer_stateless5/exp \
|
||||
--lang-dir data/lang_char \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search
|
||||
|
||||
(2) beam search (not recommended)
|
||||
./pruned_transducer_stateless5/decode.py \
|
||||
--epoch 25 \
|
||||
--avg 5 \
|
||||
--exp-dir ./pruned_transducer_stateless5/exp \
|
||||
--lang-dir data/lang_char \
|
||||
--max-duration 600 \
|
||||
--decoding-method beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(3) modified beam search
|
||||
./pruned_transducer_stateless5/decode.py \
|
||||
--epoch 25 \
|
||||
--avg 5 \
|
||||
--exp-dir ./pruned_transducer_stateless5/exp \
|
||||
--lang-dir data/lang_char \
|
||||
--max-duration 600 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(4) fast beam search (one best)
|
||||
./pruned_transducer_stateless5/decode.py \
|
||||
--epoch 25 \
|
||||
--avg 5 \
|
||||
--exp-dir ./pruned_transducer_stateless5/exp \
|
||||
--lang-dir data/lang_char \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search \
|
||||
--beam 4 \
|
||||
--max-contexts 4 \
|
||||
--max-states 8
|
||||
|
||||
(5) fast beam search (nbest)
|
||||
./pruned_transducer_stateless5/decode.py \
|
||||
--epoch 25 \
|
||||
--avg 5 \
|
||||
--exp-dir ./pruned_transducer_stateless5/exp \
|
||||
--lang-dir data/lang_char \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search_nbest \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64 \
|
||||
--num-paths 200 \
|
||||
--nbest-scale 0.5
|
||||
|
||||
(6) fast beam search (nbest oracle WER)
|
||||
./pruned_transducer_stateless5/decode.py \
|
||||
--epoch 25 \
|
||||
--avg 5 \
|
||||
--exp-dir ./pruned_transducer_stateless5/exp \
|
||||
--lang-dir data/lang_char \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search_nbest_oracle \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64 \
|
||||
--num-paths 200 \
|
||||
--nbest-scale 0.5
|
||||
|
||||
(7) fast beam search (with LG)
|
||||
./pruned_transducer_stateless5/decode.py \
|
||||
--epoch 25 \
|
||||
--avg 5 \
|
||||
--exp-dir ./pruned_transducer_stateless5/exp \
|
||||
--lang-dir data/lang_char \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search_nbest_LG \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import AiShell2AsrDataModule
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG,
|
||||
fast_beam_search_nbest_oracle,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=15,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless5/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=Path,
|
||||
default="data/lang_char",
|
||||
help="The lang dir containing word table and LG graph",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
- fast_beam_search_nbest
|
||||
- fast_beam_search_nbest_oracle
|
||||
- fast_beam_search_nbest_LG
|
||||
If you use fast_beam_search_nbest_LG, you have to specify
|
||||
`--lang-dir`, which should contain `LG.pt`.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
frame. Used only when --decoding-method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=20.0,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --decoding-method is fast_beam_search,
|
||||
fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ngram-lm-scale",
|
||||
type=float,
|
||||
default=0.01,
|
||||
help="""
|
||||
Used only when --decoding_method is fast_beam_search_nbest_LG.
|
||||
It specifies the scale for n-gram LM scores.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=64,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame.
|
||||
Used only when --decoding_method is greedy_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-paths",
|
||||
type=int,
|
||||
default=200,
|
||||
help="""Number of paths for nbest decoding.
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nbest-scale",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="""Scale applied to lattice scores when computing nbest paths.
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
lexicon: Lexicon,
|
||||
graph_compiler: CharCtcTrainingGraphCompiler,
|
||||
batch: dict,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
|
||||
- key: It indicates the setting used for decoding. For example,
|
||||
if greedy_search is used, it would be "greedy_search"
|
||||
If beam search with a beam size of 7 is used, it would be
|
||||
"beam_7"
|
||||
- value: It contains the decoding result. `len(value)` equals to
|
||||
batch size. `value[i]` is the decoding result for the i-th
|
||||
utterance in the given batch.
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
feature = batch["inputs"]
|
||||
assert feature.ndim == 3
|
||||
|
||||
feature = feature.to(device)
|
||||
# at entry, feature is (N, T, C)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=feature, x_lens=feature_lens
|
||||
)
|
||||
hyps = []
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
)
|
||||
for i in range(encoder_out.size(0)):
|
||||
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||
elif params.decoding_method == "fast_beam_search_nbest_LG":
|
||||
hyp_tokens = fast_beam_search_nbest_LG(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for hyp in hyp_tokens:
|
||||
sentence = "".join([lexicon.word_table[i] for i in hyp])
|
||||
hyps.append(list(sentence))
|
||||
elif params.decoding_method == "fast_beam_search_nbest":
|
||||
hyp_tokens = fast_beam_search_nbest(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for i in range(encoder_out.size(0)):
|
||||
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||
elif params.decoding_method == "fast_beam_search_nbest_oracle":
|
||||
hyp_tokens = fast_beam_search_nbest_oracle(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
ref_texts=graph_compiler.texts_to_ids(supervisions["text"]),
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for i in range(encoder_out.size(0)):
|
||||
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||
elif (
|
||||
params.decoding_method == "greedy_search"
|
||||
and params.max_sym_per_frame == 1
|
||||
):
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
for i in range(encoder_out.size(0)):
|
||||
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
for i in range(encoder_out.size(0)):
|
||||
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||
else:
|
||||
batch_size = encoder_out.size(0)
|
||||
|
||||
for i in range(batch_size):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.decoding_method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.decoding_method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
hyps.append([lexicon.token_table[idx] for idx in hyp])
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
return {"greedy_search": hyps}
|
||||
elif "fast_beam_search" in params.decoding_method:
|
||||
key = f"beam_{params.beam}_"
|
||||
key += f"max_contexts_{params.max_contexts}_"
|
||||
key += f"max_states_{params.max_states}"
|
||||
if "nbest" in params.decoding_method:
|
||||
key += f"_num_paths_{params.num_paths}_"
|
||||
key += f"nbest_scale_{params.nbest_scale}"
|
||||
if "LG" in params.decoding_method:
|
||||
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
|
||||
|
||||
return {key: hyps}
|
||||
else:
|
||||
return {f"beam_size_{params.beam_size}": hyps}
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
lexicon: Lexicon,
|
||||
graph_compiler: CharCtcTrainingGraphCompiler,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
dl:
|
||||
PyTorch's dataloader containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
Its value is a list of tuples. Each tuple contains two elements:
|
||||
The first is the reference transcript, and the second is the
|
||||
predicted result.
|
||||
"""
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
log_interval = 50
|
||||
else:
|
||||
log_interval = 20
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
lexicon=lexicon,
|
||||
graph_compiler=graph_compiler,
|
||||
decoding_graph=decoding_graph,
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
this_batch.append((cut_id, ref_text, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
num_cuts += len(texts)
|
||||
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(
|
||||
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = (
|
||||
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results, enable_log=True
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
params.res_dir
|
||||
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
AiShell2AsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
assert params.decoding_method in (
|
||||
"greedy_search",
|
||||
"beam_search",
|
||||
"fast_beam_search",
|
||||
"fast_beam_search_nbest",
|
||||
"fast_beam_search_nbest_LG",
|
||||
"fast_beam_search_nbest_oracle",
|
||||
"modified_beam_search",
|
||||
)
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam}"
|
||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||
params.suffix += f"-max-states-{params.max_states}"
|
||||
if "nbest" in params.decoding_method:
|
||||
params.suffix += f"-nbest-scale-{params.nbest_scale}"
|
||||
params.suffix += f"-num-paths-{params.num_paths}"
|
||||
if "LG" in params.decoding_method:
|
||||
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += (
|
||||
f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
)
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
params.blank_id = lexicon.token_table["<blk>"]
|
||||
params.unk_id = lexicon.token_table["<unk>"]
|
||||
params.vocab_size = max(lexicon.tokens) + 1
|
||||
|
||||
graph_compiler = CharCtcTrainingGraphCompiler(
|
||||
lexicon=lexicon,
|
||||
device=device,
|
||||
)
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg + 1]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
if params.decoding_method == "fast_beam_search_nbest_LG":
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
decoding_graph = k2.trivial_graph(
|
||||
params.vocab_size - 1, device=device
|
||||
)
|
||||
else:
|
||||
decoding_graph = None
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
aishell2 = AiShell2AsrDataModule(args)
|
||||
|
||||
valid_cuts = aishell2.valid_cuts()
|
||||
test_cuts = aishell2.test_cuts()
|
||||
|
||||
# use ios sets for dev and test
|
||||
dev_dl = aishell2.valid_dataloaders(valid_cuts)
|
||||
test_dl = aishell2.test_dataloaders(test_cuts)
|
||||
|
||||
test_sets = ["dev", "test"]
|
||||
test_dl = [dev_dl, test_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dl):
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
lexicon=lexicon,
|
||||
graph_compiler=graph_compiler,
|
||||
decoding_graph=decoding_graph,
|
||||
)
|
||||
|
||||
save_results(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
1
egs/aishell2/ASR/pruned_transducer_stateless5/decoder.py
Symbolic link
1
egs/aishell2/ASR/pruned_transducer_stateless5/decoder.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless2/decoder.py
|
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py
|
274
egs/aishell2/ASR/pruned_transducer_stateless5/export.py
Executable file
274
egs/aishell2/ASR/pruned_transducer_stateless5/export.py
Executable file
@ -0,0 +1,274 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This script converts several saved checkpoints
|
||||
# to a single one using model averaging.
|
||||
"""
|
||||
Usage:
|
||||
./pruned_transducer_stateless5/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless5/exp \
|
||||
--lang-dir data/lang_char
|
||||
--epoch 25 \
|
||||
--avg 5
|
||||
|
||||
It will generate a file exp_dir/pretrained.pt
|
||||
|
||||
To use the generated file with `pruned_transducer_stateless5/decode.py`,
|
||||
you can do:
|
||||
|
||||
cd /path/to/exp_dir
|
||||
ln -s pretrained.pt epoch-9999.pt
|
||||
|
||||
cd /path/to/egs/aishell2/ASR
|
||||
./pruned_transducer_stateless5/decode.py \
|
||||
--exp-dir ./pruned_transducer_stateless5/exp \
|
||||
--epoch 9999 \
|
||||
--avg 1 \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search \
|
||||
--lang-dir data/lang_char
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=28,
|
||||
help="""It specifies the checkpoint to use for averaging.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=15,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless5/exp",
|
||||
help="""It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
default="data/lang_char",
|
||||
help="The lang dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--jit",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""True to save a model after applying torch.jit.script.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
params.blank_id = lexicon.token_table["<blk>"]
|
||||
params.unk_id = lexicon.token_table["<unk>"]
|
||||
params.vocab_size = max(lexicon.tokens) + 1
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg + 1]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to("cpu")
|
||||
model.eval()
|
||||
|
||||
if params.jit:
|
||||
# We won't use the forward() method of the model in C++, so just ignore
|
||||
# it here.
|
||||
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||
# torch scriptabe.
|
||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||
logging.info("Using torch.jit.script")
|
||||
model = torch.jit.script(model)
|
||||
filename = params.exp_dir / "cpu_jit.pt"
|
||||
model.save(str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
else:
|
||||
logging.info("Not using torch.jit.script")
|
||||
# Save it using a format so that it can be loaded
|
||||
# by :func:`load_checkpoint`
|
||||
filename = params.exp_dir / "pretrained.pt"
|
||||
torch.save({"model": model.state_dict()}, str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
1
egs/aishell2/ASR/pruned_transducer_stateless5/joiner.py
Symbolic link
1
egs/aishell2/ASR/pruned_transducer_stateless5/joiner.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless2/joiner.py
|
1
egs/aishell2/ASR/pruned_transducer_stateless5/model.py
Symbolic link
1
egs/aishell2/ASR/pruned_transducer_stateless5/model.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless2/model.py
|
1
egs/aishell2/ASR/pruned_transducer_stateless5/optim.py
Symbolic link
1
egs/aishell2/ASR/pruned_transducer_stateless5/optim.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless2/optim.py
|
342
egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py
Executable file
342
egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py
Executable file
@ -0,0 +1,342 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
|
||||
(1) greedy search
|
||||
./pruned_transducer_stateless5/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
|
||||
--lang-dir ./data/lang_char \
|
||||
--method greedy_search \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(2) modified beam search
|
||||
./pruned_transducer_stateless5/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
|
||||
--lang-dir ./data/lang_char \
|
||||
--method modified_beam_search \
|
||||
--beam-size 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(3) fast beam search
|
||||
./pruned_transducer_stateless5/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
|
||||
--lang-dir ./data/lang_char \
|
||||
--method fast_beam_search \
|
||||
--beam-size 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
You can also use `./pruned_transducer_stateless5/exp/epoch-xx.pt`.
|
||||
|
||||
Note: ./pruned_transducer_stateless5/exp/pretrained.pt is generated by
|
||||
./pruned_transducer_stateless5/export.py
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import k2
|
||||
import kaldifeat
|
||||
import torch
|
||||
import torchaudio
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.lexicon import Lexicon
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the checkpoint. "
|
||||
"The checkpoint is assumed to be saved by "
|
||||
"icefall.checkpoint.save_checkpoint().",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
help="""Path to lang.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="The sample rate of the input sound file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
frame. Used only when --method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=4,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""Used only when --method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame. Used only when
|
||||
--method is greedy_search.
|
||||
""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert sample_rate == expected_sample_rate, (
|
||||
f"expected sample rate: {expected_sample_rate}. "
|
||||
f"Given: {sample_rate}"
|
||||
)
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
params = get_params()
|
||||
|
||||
params.update(vars(args))
|
||||
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
params.blank_id = lexicon.token_table["<blk>"]
|
||||
params.unk_id = lexicon.token_table["<unk>"]
|
||||
params.vocab_size = max(lexicon.tokens) + 1
|
||||
|
||||
logging.info(f"{params}")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
logging.info("Creating model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
model.device = device
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = device
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = params.sample_rate
|
||||
opts.mel_opts.num_bins = params.feature_dim
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
logging.info(f"Reading sound files: {params.sound_files}")
|
||||
waves = read_sound_files(
|
||||
filenames=params.sound_files, expected_sample_rate=params.sample_rate
|
||||
)
|
||||
waves = [w.to(device) for w in waves]
|
||||
|
||||
logging.info("Decoding started")
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
|
||||
features = pad_sequence(
|
||||
features, batch_first=True, padding_value=math.log(1e-10)
|
||||
)
|
||||
|
||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=features, x_lens=feature_lengths
|
||||
)
|
||||
|
||||
num_waves = encoder_out.size(0)
|
||||
hyps = []
|
||||
msg = f"Using {params.method}"
|
||||
if params.method == "beam_search":
|
||||
msg += f" with beam size {params.beam_size}"
|
||||
logging.info(msg)
|
||||
|
||||
if params.method == "fast_beam_search":
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
)
|
||||
for i in range(encoder_out.size(0)):
|
||||
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||
elif params.method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
|
||||
for i in range(encoder_out.size(0)):
|
||||
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
for i in range(encoder_out.size(0)):
|
||||
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||
else:
|
||||
for i in range(num_waves):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported method: {params.method}")
|
||||
|
||||
hyps.append([lexicon.token_table[idx] for idx in hyp])
|
||||
|
||||
s = "\n"
|
||||
for filename, hyp in zip(params.sound_files, hyps):
|
||||
words = "".join(hyp)
|
||||
s += f"{filename}:\n{words}\n\n"
|
||||
logging.info(s)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
1
egs/aishell2/ASR/pruned_transducer_stateless5/scaling.py
Symbolic link
1
egs/aishell2/ASR/pruned_transducer_stateless5/scaling.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless2/scaling.py
|
1131
egs/aishell2/ASR/pruned_transducer_stateless5/train.py
Executable file
1131
egs/aishell2/ASR/pruned_transducer_stateless5/train.py
Executable file
File diff suppressed because it is too large
Load Diff
1
egs/aishell2/ASR/shared
Symbolic link
1
egs/aishell2/ASR/shared
Symbolic link
@ -0,0 +1 @@
|
||||
../../../icefall/shared/
|
File diff suppressed because it is too large
Load Diff
1
egs/aishell4/ASR/pruned_transducer_stateless5/conformer.py
Symbolic link
1
egs/aishell4/ASR/pruned_transducer_stateless5/conformer.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless5/conformer.py
|
@ -378,6 +378,7 @@ def decode_dataset(
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
texts = [list(str(text).replace(" ", "")) for text in texts]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
@ -390,8 +391,8 @@ def decode_dataset(
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for hyp_words, ref_text in zip(hyps, texts):
|
||||
this_batch.append((ref_text, hyp_words))
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
this_batch.append((cut_id, ref_text, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
@ -416,6 +417,7 @@ def save_results(
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
@ -523,7 +525,9 @@ def main():
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
model.load_state_dict(
|
||||
average_checkpoints(filenames, device=device), strict=False
|
||||
)
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
@ -534,7 +538,9 @@ def main():
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
model.load_state_dict(
|
||||
average_checkpoints(filenames, device=device), strict=False
|
||||
)
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
@ -562,7 +568,8 @@ def main():
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
),
|
||||
strict=False,
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
@ -580,7 +587,8 @@ def main():
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
),
|
||||
strict=False,
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
@ -601,6 +609,8 @@ def main():
|
||||
c.supervisions[0].text = text_normalize(text)
|
||||
return c
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
aishell4 = Aishell4AsrDataModule(args)
|
||||
test_cuts = aishell4.test_cuts()
|
||||
test_cuts = test_cuts.map(text_normalize_for_cut)
|
||||
|
@ -184,7 +184,9 @@ def main():
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
model.load_state_dict(
|
||||
average_checkpoints(filenames, device=device), strict=False
|
||||
)
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
@ -195,7 +197,9 @@ def main():
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
model.load_state_dict(
|
||||
average_checkpoints(filenames, device=device), strict=False
|
||||
)
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
@ -223,7 +227,8 @@ def main():
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
),
|
||||
strict=False,
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
@ -241,7 +246,8 @@ def main():
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
),
|
||||
strict=False,
|
||||
)
|
||||
|
||||
model.eval()
|
||||
|
@ -367,6 +367,7 @@ def decode_dataset(
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
texts = [list(str(text).replace(" ", "")) for text in texts]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
@ -379,8 +380,8 @@ def decode_dataset(
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for hyp_words, ref_text in zip(hyps, texts):
|
||||
this_batch.append((ref_text, hyp_words))
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
this_batch.append((cut_id, ref_text, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
@ -405,6 +406,7 @@ def save_results(
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
@ -535,6 +537,8 @@ def main():
|
||||
from lhotse import CutSet
|
||||
from lhotse.dataset.webdataset import export_to_webdataset
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
alimeeting = AlimeetingAsrDataModule(args)
|
||||
|
||||
dev = "eval"
|
||||
|
@ -451,6 +451,7 @@ def decode_dataset(
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
@ -469,9 +470,9 @@ def decode_dataset(
|
||||
for lm_scale, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for hyp_words, ref_text in zip(hyps, texts):
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((ref_words, hyp_words))
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[lm_scale].extend(this_batch)
|
||||
else:
|
||||
@ -512,6 +513,7 @@ def save_results(
|
||||
for key, results in results_dict.items():
|
||||
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
|
||||
results = post_processing(results)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
if enable_log:
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
@ -676,6 +678,8 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
gigaspeech = GigaSpeechAsrDataModule(args)
|
||||
|
||||
dev_cuts = gigaspeech.dev_cuts()
|
||||
|
@ -20,11 +20,7 @@ import logging
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import (
|
||||
CutSet,
|
||||
KaldifeatFbank,
|
||||
KaldifeatFbankConfig,
|
||||
)
|
||||
from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slow things down.
|
||||
@ -69,6 +65,7 @@ def compute_fbank_gigaspeech_dev_test():
|
||||
storage_path=f"{in_out_dir}/feats_{partition}",
|
||||
num_workers=num_workers,
|
||||
batch_duration=batch_duration,
|
||||
overwrite=True,
|
||||
)
|
||||
cut_set = cut_set.trim_to_supervisions(
|
||||
keep_overlapping=False, min_duration=None
|
||||
|
@ -22,11 +22,7 @@ from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import (
|
||||
CutSet,
|
||||
KaldifeatFbank,
|
||||
KaldifeatFbankConfig,
|
||||
)
|
||||
from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slow things down.
|
||||
@ -120,6 +116,7 @@ def compute_fbank_gigaspeech_splits(args):
|
||||
storage_path=f"{output_dir}/feats_XL_{idx}",
|
||||
num_workers=args.num_workers,
|
||||
batch_duration=args.batch_duration,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
logging.info("About to split cuts into smaller chunks.")
|
||||
|
@ -374,6 +374,7 @@ def decode_dataset(
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
@ -386,9 +387,9 @@ def decode_dataset(
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for hyp_words, ref_text in zip(hyps, texts):
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((ref_words, hyp_words))
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
@ -414,6 +415,7 @@ def save_results(
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = post_processing(results)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
@ -544,6 +546,8 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
gigaspeech = GigaSpeechAsrDataModule(args)
|
||||
|
||||
dev_cuts = gigaspeech.dev_cuts()
|
||||
|
@ -23,8 +23,9 @@ The following table lists the differences among them.
|
||||
| `pruned_transducer_stateless5` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless4 + more layers + random combiner|
|
||||
| `pruned_transducer_stateless6` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless4 + distillation with hubert|
|
||||
| `pruned_stateless_emformer_rnnt2` | Emformer(from torchaudio) | Embedding + Conv1d | Using Emformer from torchaudio for streaming ASR|
|
||||
| `conv_emformer_transducer_stateless` | Emformer | Embedding + Conv1d | Using Emformer augmented with convolution for streaming ASR + mechanisms in reworked model |
|
||||
|
||||
| `conv_emformer_transducer_stateless` | ConvEmformer | Embedding + Conv1d | Using ConvEmformer for streaming ASR + mechanisms in reworked model |
|
||||
| `conv_emformer_transducer_stateless2` | ConvEmformer | Embedding + Conv1d | Using ConvEmformer with simplified memory for streaming ASR + mechanisms in reworked model |
|
||||
| `lstm_transducer_stateless` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model |
|
||||
|
||||
The decoder in `transducer_stateless` is modified from the paper
|
||||
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
|
||||
|
@ -1,5 +1,403 @@
|
||||
## Results
|
||||
|
||||
#### LibriSpeech BPE training results (Pruned Stateless LSTM RNN-T)
|
||||
|
||||
[lstm_transducer_stateless](./lstm_transducer_stateless)
|
||||
|
||||
It implements LSTM model with mechanisms in reworked model for streaming ASR.
|
||||
|
||||
See <https://github.com/k2-fsa/icefall/pull/479> for more details.
|
||||
|
||||
#### training on full librispeech
|
||||
|
||||
This model contains 12 encoder layers (LSTM module + Feedforward module). The number of model parameters is 84689496.
|
||||
|
||||
The WERs are:
|
||||
|
||||
| | test-clean | test-other | comment | decoding mode |
|
||||
|-------------------------------------|------------|------------|----------------------|----------------------|
|
||||
| greedy search (max sym per frame 1) | 3.81 | 9.73 | --epoch 35 --avg 15 | simulated streaming |
|
||||
| greedy search (max sym per frame 1) | 3.78 | 9.79 | --epoch 35 --avg 15 | streaming |
|
||||
| fast beam search | 3.74 | 9.59 | --epoch 35 --avg 15 | simulated streaming |
|
||||
| fast beam search | 3.73 | 9.61 | --epoch 35 --avg 15 | streaming |
|
||||
| modified beam search | 3.64 | 9.55 | --epoch 35 --avg 15 | simulated streaming |
|
||||
| modified beam search | 3.65 | 9.51 | --epoch 35 --avg 15 | streaming |
|
||||
|
||||
Note: `simulated streaming` indicates feeding full utterance during decoding, while `streaming` indicates feeding certain number of frames at each time.
|
||||
|
||||
The training command is:
|
||||
|
||||
```bash
|
||||
./lstm_transducer_stateless/train.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 35 \
|
||||
--start-epoch 1 \
|
||||
--exp-dir lstm_transducer_stateless/exp \
|
||||
--full-libri 1 \
|
||||
--max-duration 500 \
|
||||
--master-port 12321 \
|
||||
--num-encoder-layers 12 \
|
||||
--rnn-hidden-size 1024
|
||||
```
|
||||
|
||||
The tensorboard log can be found at
|
||||
<https://tensorboard.dev/experiment/FWrM20mjTeWo6dTpFYOsYQ/>
|
||||
|
||||
The simulated streaming decoding command using greedy search, fast beam search, and modified beam search is:
|
||||
```bash
|
||||
for decoding_method in greedy_search fast_beam_search modified_beam_search; do
|
||||
./lstm_transducer_stateless/decode.py \
|
||||
--epoch 35 \
|
||||
--avg 15 \
|
||||
--exp-dir lstm_transducer_stateless/exp \
|
||||
--max-duration 600 \
|
||||
--num-encoder-layers 12 \
|
||||
--rnn-hidden-size 1024 \
|
||||
--decoding-method $decoding_method \
|
||||
--use-averaged-model True \
|
||||
--beam 4 \
|
||||
--max-contexts 4 \
|
||||
--max-states 8 \
|
||||
--beam-size 4
|
||||
done
|
||||
```
|
||||
|
||||
The streaming decoding command using greedy search, fast beam search, and modified beam search is:
|
||||
```bash
|
||||
for decoding_method in greedy_search fast_beam_search modified_beam_search; do
|
||||
./lstm_transducer_stateless/streaming_decode.py \
|
||||
--epoch 35 \
|
||||
--avg 15 \
|
||||
--exp-dir lstm_transducer_stateless/exp \
|
||||
--max-duration 600 \
|
||||
--num-encoder-layers 12 \
|
||||
--rnn-hidden-size 1024 \
|
||||
--decoding-method $decoding_method \
|
||||
--use-averaged-model True \
|
||||
--beam 4 \
|
||||
--max-contexts 4 \
|
||||
--max-states 8 \
|
||||
--beam-size 4
|
||||
done
|
||||
```
|
||||
|
||||
Pretrained models, training logs, decoding logs, and decoding results
|
||||
are available at
|
||||
<https://huggingface.co/Zengwei/icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18>
|
||||
|
||||
|
||||
#### LibriSpeech BPE training results (Pruned Stateless Conv-Emformer RNN-T 2)
|
||||
|
||||
[conv_emformer_transducer_stateless2](./conv_emformer_transducer_stateless2)
|
||||
|
||||
It implements [Emformer](https://arxiv.org/abs/2010.10759) augmented with convolution module and simplified memory bank for streaming ASR.
|
||||
It is modified from [torchaudio](https://github.com/pytorch/audio).
|
||||
|
||||
See <https://github.com/k2-fsa/icefall/pull/440> for more details.
|
||||
|
||||
#### With lower latency setup, training on full librispeech
|
||||
|
||||
In this model, the lengths of chunk and right context are 32 frames (i.e., 0.32s) and 8 frames (i.e., 0.08s), respectively.
|
||||
|
||||
The WERs are:
|
||||
|
||||
| | test-clean | test-other | comment | decoding mode |
|
||||
|-------------------------------------|------------|------------|----------------------|----------------------|
|
||||
| greedy search (max sym per frame 1) | 3.5 | 9.09 | --epoch 30 --avg 10 | simulated streaming |
|
||||
| greedy search (max sym per frame 1) | 3.57 | 9.1 | --epoch 30 --avg 10 | streaming |
|
||||
| fast beam search | 3.5 | 8.91 | --epoch 30 --avg 10 | simulated streaming |
|
||||
| fast beam search | 3.54 | 8.91 | --epoch 30 --avg 10 | streaming |
|
||||
| modified beam search | 3.43 | 8.86 | --epoch 30 --avg 10 | simulated streaming |
|
||||
| modified beam search | 3.48 | 8.88 | --epoch 30 --avg 10 | streaming |
|
||||
|
||||
The training command is:
|
||||
|
||||
```bash
|
||||
./conv_emformer_transducer_stateless2/train.py \
|
||||
--world-size 6 \
|
||||
--num-epochs 30 \
|
||||
--start-epoch 1 \
|
||||
--exp-dir conv_emformer_transducer_stateless2/exp \
|
||||
--full-libri 1 \
|
||||
--max-duration 280 \
|
||||
--master-port 12321 \
|
||||
--num-encoder-layers 12 \
|
||||
--chunk-length 32 \
|
||||
--cnn-module-kernel 31 \
|
||||
--left-context-length 32 \
|
||||
--right-context-length 8 \
|
||||
--memory-size 32
|
||||
```
|
||||
|
||||
The tensorboard log can be found at
|
||||
<https://tensorboard.dev/experiment/W5MpxekiQLSPyM4fe5hbKg/>
|
||||
|
||||
The simulated streaming decoding command using greedy search is:
|
||||
```bash
|
||||
./conv_emformer_transducer_stateless2/decode.py \
|
||||
--epoch 30 \
|
||||
--avg 10 \
|
||||
--exp-dir conv_emformer_transducer_stateless2/exp \
|
||||
--max-duration 300 \
|
||||
--num-encoder-layers 12 \
|
||||
--chunk-length 32 \
|
||||
--cnn-module-kernel 31 \
|
||||
--left-context-length 32 \
|
||||
--right-context-length 8 \
|
||||
--memory-size 32 \
|
||||
--decoding-method greedy_search \
|
||||
--use-averaged-model True
|
||||
```
|
||||
|
||||
The simulated streaming decoding command using fast beam search is:
|
||||
```bash
|
||||
./conv_emformer_transducer_stateless2/decode.py \
|
||||
--epoch 30 \
|
||||
--avg 10 \
|
||||
--exp-dir conv_emformer_transducer_stateless2/exp \
|
||||
--max-duration 300 \
|
||||
--num-encoder-layers 12 \
|
||||
--chunk-length 32 \
|
||||
--cnn-module-kernel 31 \
|
||||
--left-context-length 32 \
|
||||
--right-context-length 8 \
|
||||
--memory-size 32 \
|
||||
--decoding-method fast_beam_search \
|
||||
--use-averaged-model True \
|
||||
--beam 4 \
|
||||
--max-contexts 4 \
|
||||
--max-states 8
|
||||
```
|
||||
|
||||
The simulated streaming decoding command using modified beam search is:
|
||||
```bash
|
||||
./conv_emformer_transducer_stateless2/decode.py \
|
||||
--epoch 30 \
|
||||
--avg 10 \
|
||||
--exp-dir conv_emformer_transducer_stateless2/exp \
|
||||
--max-duration 300 \
|
||||
--num-encoder-layers 12 \
|
||||
--chunk-length 32 \
|
||||
--cnn-module-kernel 31 \
|
||||
--left-context-length 32 \
|
||||
--right-context-length 8 \
|
||||
--memory-size 32 \
|
||||
--decoding-method modified_beam_search \
|
||||
--use-averaged-model True \
|
||||
--beam-size 4
|
||||
```
|
||||
|
||||
The streaming decoding command using greedy search is:
|
||||
```bash
|
||||
./conv_emformer_transducer_stateless2/streaming_decode.py \
|
||||
--epoch 30 \
|
||||
--avg 10 \
|
||||
--exp-dir conv_emformer_transducer_stateless2/exp \
|
||||
--num-decode-streams 2000 \
|
||||
--num-encoder-layers 12 \
|
||||
--chunk-length 32 \
|
||||
--cnn-module-kernel 31 \
|
||||
--left-context-length 32 \
|
||||
--right-context-length 8 \
|
||||
--memory-size 32 \
|
||||
--decoding-method greedy_search \
|
||||
--use-averaged-model True
|
||||
```
|
||||
|
||||
The streaming decoding command using fast beam search is:
|
||||
```bash
|
||||
./conv_emformer_transducer_stateless2/streaming_decode.py \
|
||||
--epoch 30 \
|
||||
--avg 10 \
|
||||
--exp-dir conv_emformer_transducer_stateless2/exp \
|
||||
--num-decode-streams 2000 \
|
||||
--num-encoder-layers 12 \
|
||||
--chunk-length 32 \
|
||||
--cnn-module-kernel 31 \
|
||||
--left-context-length 32 \
|
||||
--right-context-length 8 \
|
||||
--memory-size 32 \
|
||||
--decoding-method fast_beam_search \
|
||||
--use-averaged-model True \
|
||||
--beam 4 \
|
||||
--max-contexts 4 \
|
||||
--max-states 8
|
||||
```
|
||||
|
||||
The streaming decoding command using modified beam search is:
|
||||
```bash
|
||||
./conv_emformer_transducer_stateless2/streaming_decode.py \
|
||||
--epoch 30 \
|
||||
--avg 10 \
|
||||
--exp-dir conv_emformer_transducer_stateless2/exp \
|
||||
--num-decode-streams 2000 \
|
||||
--num-encoder-layers 12 \
|
||||
--chunk-length 32 \
|
||||
--cnn-module-kernel 31 \
|
||||
--left-context-length 32 \
|
||||
--right-context-length 8 \
|
||||
--memory-size 32 \
|
||||
--decoding-method modified_beam_search \
|
||||
--use-averaged-model True \
|
||||
--beam-size 4
|
||||
```
|
||||
|
||||
Pretrained models, training logs, decoding logs, and decoding results
|
||||
are available at
|
||||
<https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05>
|
||||
|
||||
#### With higher latency setup, training on full librispeech
|
||||
|
||||
In this model, the lengths of chunk and right context are 64 frames (i.e., 0.64s) and 16 frames (i.e., 0.16s), respectively.
|
||||
|
||||
The WERs are:
|
||||
|
||||
| | test-clean | test-other | comment | decoding mode |
|
||||
|-------------------------------------|------------|------------|----------------------|----------------------|
|
||||
| greedy search (max sym per frame 1) | 3.3 | 8.71 | --epoch 30 --avg 10 | simulated streaming |
|
||||
| greedy search (max sym per frame 1) | 3.35 | 8.65 | --epoch 30 --avg 10 | streaming |
|
||||
| fast beam search | 3.27 | 8.58 | --epoch 30 --avg 10 | simulated streaming |
|
||||
| fast beam search | 3.31 | 8.48 | --epoch 30 --avg 10 | streaming |
|
||||
| modified beam search | 3.26 | 8.56 | --epoch 30 --avg 10 | simulated streaming |
|
||||
| modified beam search | 3.29 | 8.47 | --epoch 30 --avg 10 | streaming |
|
||||
|
||||
The training command is:
|
||||
|
||||
```bash
|
||||
./conv_emformer_transducer_stateless2/train.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 30 \
|
||||
--start-epoch 1 \
|
||||
--exp-dir conv_emformer_transducer_stateless2/exp \
|
||||
--full-libri 1 \
|
||||
--max-duration 280 \
|
||||
--master-port 12321 \
|
||||
--num-encoder-layers 12 \
|
||||
--chunk-length 64 \
|
||||
--cnn-module-kernel 31 \
|
||||
--left-context-length 64 \
|
||||
--right-context-length 16 \
|
||||
--memory-size 32
|
||||
```
|
||||
|
||||
The tensorboard log can be found at
|
||||
<https://tensorboard.dev/experiment/eRx6XwbOQhGlywgD8lWBjw/>
|
||||
|
||||
The simulated streaming decoding command using greedy search is:
|
||||
```bash
|
||||
./conv_emformer_transducer_stateless2/decode.py \
|
||||
--epoch 30 \
|
||||
--avg 10 \
|
||||
--exp-dir conv_emformer_transducer_stateless2/exp \
|
||||
--max-duration 300 \
|
||||
--num-encoder-layers 12 \
|
||||
--chunk-length 64 \
|
||||
--cnn-module-kernel 31 \
|
||||
--left-context-length 64 \
|
||||
--right-context-length 16 \
|
||||
--memory-size 32 \
|
||||
--decoding-method greedy_search \
|
||||
--use-averaged-model True
|
||||
```
|
||||
|
||||
The simulated streaming decoding command using fast beam search is:
|
||||
```bash
|
||||
./conv_emformer_transducer_stateless2/decode.py \
|
||||
--epoch 30 \
|
||||
--avg 10 \
|
||||
--exp-dir conv_emformer_transducer_stateless2/exp \
|
||||
--max-duration 300 \
|
||||
--num-encoder-layers 12 \
|
||||
--chunk-length 64 \
|
||||
--cnn-module-kernel 31 \
|
||||
--left-context-length 64 \
|
||||
--right-context-length 16 \
|
||||
--memory-size 32 \
|
||||
--decoding-method fast_beam_search \
|
||||
--use-averaged-model True \
|
||||
--beam 4 \
|
||||
--max-contexts 4 \
|
||||
--max-states 8
|
||||
```
|
||||
|
||||
The simulated streaming decoding command using modified beam search is:
|
||||
```bash
|
||||
./conv_emformer_transducer_stateless2/decode.py \
|
||||
--epoch 30 \
|
||||
--avg 10 \
|
||||
--exp-dir conv_emformer_transducer_stateless2/exp \
|
||||
--max-duration 300 \
|
||||
--num-encoder-layers 12 \
|
||||
--chunk-length 64 \
|
||||
--cnn-module-kernel 31 \
|
||||
--left-context-length 64 \
|
||||
--right-context-length 16 \
|
||||
--memory-size 32 \
|
||||
--decoding-method modified_beam_search \
|
||||
--use-averaged-model True \
|
||||
--beam-size 4
|
||||
```
|
||||
|
||||
The streaming decoding command using greedy search is:
|
||||
```bash
|
||||
./conv_emformer_transducer_stateless2/streaming_decode.py \
|
||||
--epoch 30 \
|
||||
--avg 10 \
|
||||
--exp-dir conv_emformer_transducer_stateless2/exp \
|
||||
--num-decode-streams 2000 \
|
||||
--num-encoder-layers 12 \
|
||||
--chunk-length 64 \
|
||||
--cnn-module-kernel 31 \
|
||||
--left-context-length 64 \
|
||||
--right-context-length 16 \
|
||||
--memory-size 32 \
|
||||
--decoding-method greedy_search \
|
||||
--use-averaged-model True
|
||||
```
|
||||
|
||||
The streaming decoding command using fast beam search is:
|
||||
```bash
|
||||
./conv_emformer_transducer_stateless2/streaming_decode.py \
|
||||
--epoch 30 \
|
||||
--avg 10 \
|
||||
--exp-dir conv_emformer_transducer_stateless2/exp \
|
||||
--num-decode-streams 2000 \
|
||||
--num-encoder-layers 12 \
|
||||
--chunk-length 64 \
|
||||
--cnn-module-kernel 31 \
|
||||
--left-context-length 64 \
|
||||
--right-context-length 16 \
|
||||
--memory-size 32 \
|
||||
--decoding-method fast_beam_search \
|
||||
--use-averaged-model True \
|
||||
--beam 4 \
|
||||
--max-contexts 4 \
|
||||
--max-states 8
|
||||
```
|
||||
|
||||
The streaming decoding command using modified beam search is:
|
||||
```bash
|
||||
./conv_emformer_transducer_stateless2/streaming_decode.py \
|
||||
--epoch 30 \
|
||||
--avg 10 \
|
||||
--exp-dir conv_emformer_transducer_stateless2/exp \
|
||||
--num-decode-streams 2000 \
|
||||
--num-encoder-layers 12 \
|
||||
--chunk-length 64 \
|
||||
--cnn-module-kernel 31 \
|
||||
--left-context-length 64 \
|
||||
--right-context-length 16 \
|
||||
--memory-size 32 \
|
||||
--decoding-method modified_beam_search \
|
||||
--use-averaged-model True \
|
||||
--beam-size 4
|
||||
```
|
||||
|
||||
Pretrained models, training logs, decoding logs, and decoding results
|
||||
are available at
|
||||
<https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-larger-latency-2022-07-06>
|
||||
|
||||
|
||||
### LibriSpeech BPE training results (Pruned Stateless Streaming Conformer RNN-T)
|
||||
|
||||
#### [pruned_transducer_stateless](./pruned_transducer_stateless)
|
||||
@ -306,6 +704,80 @@ done
|
||||
|
||||
Pre-trained models, training and decoding logs, and decoding results are available at <https://huggingface.co/pkufool/icefall_librispeech_streaming_pruned_transducer_stateless4_20220625>
|
||||
|
||||
#### [pruned_transducer_stateless5](./pruned_transducer_stateless5)
|
||||
|
||||
See <https://github.com/k2-fsa/icefall/pull/454> for more details.
|
||||
|
||||
##### Training on full librispeech
|
||||
The WERs are (the number in the table formatted as test-clean & test-other):
|
||||
|
||||
We only trained 25 epochs for saving time, if you want to get better results you can train more epochs.
|
||||
|
||||
| decoding method | left context | chunk size = 2 | chunk size = 4 | chunk size = 8 | chunk size = 16|
|
||||
|----------------------|--------------|----------------|----------------|----------------|----------------|
|
||||
| greedy search | 32 | 3.93 & 9.88 | 3.64 & 9.43 | 3.51 & 8.92 | 3.26 & 8.37 |
|
||||
| greedy search | 64 | 4.84 & 9.81 | 3.59 & 9.27 | 3.44 & 8.83 | 3.23 & 8.33 |
|
||||
| fast beam search | 32 | 3.86 & 9.77 | 3.67 & 9.3 | 3.5 & 8.83 | 3.27 & 8.33 |
|
||||
| fast beam search | 64 | 3.79 & 9.68 | 3.57 & 9.21 | 3.41 & 8.72 | 3.25 & 8.27 |
|
||||
| modified beam search | 32 | 3.84 & 9.71 | 3.66 & 9.38 | 3.47 & 8.86 | 3.26 & 8.42 |
|
||||
| modified beam search | 64 | 3.81 & 9.59 | 3.58 & 9.2 | 3.44 & 8.74 | 3.23 & 8.35 |
|
||||
|
||||
|
||||
**NOTE:** The WERs in table above were decoded with simulate streaming method (i.e. using masking strategy), see commands below. We also have [real streaming decoding](./pruned_transducer_stateless5/streaming_decode.py) script which should produce almost the same results. We tried adding right context in the real streaming decoding, but it seemed not to benefit the performance for all the models, the reasons might be the training and decoding mismatching.
|
||||
|
||||
The training command is:
|
||||
|
||||
```bash
|
||||
./pruned_transducer_stateless5/train.py \
|
||||
--exp-dir pruned_transducer_stateless5/exp \
|
||||
--num-encoder-layers 18 \
|
||||
--dim-feedforward 2048 \
|
||||
--nhead 8 \
|
||||
--encoder-dim 512 \
|
||||
--decoder-dim 512 \
|
||||
--joiner-dim 512 \
|
||||
--full-libri 1 \
|
||||
--dynamic-chunk-training 1 \
|
||||
--causal-convolution 1 \
|
||||
--short-chunk-size 20 \
|
||||
--num-left-chunks 4 \
|
||||
--max-duration 300 \
|
||||
--world-size 4 \
|
||||
--start-epoch 1 \
|
||||
--num-epochs 25
|
||||
```
|
||||
|
||||
You can find the tensorboard log here <https://tensorboard.dev/experiment/rO04h6vjTLyw0qSxjp4m4Q>
|
||||
|
||||
The decoding command is:
|
||||
```bash
|
||||
decoding_method="greedy_search" # "fast_beam_search", "modified_beam_search"
|
||||
|
||||
for chunk in 2 4 8 16; do
|
||||
for left in 32 64; do
|
||||
./pruned_transducer_stateless5/decode.py \
|
||||
--num-encoder-layers 18 \
|
||||
--dim-feedforward 2048 \
|
||||
--nhead 8 \
|
||||
--encoder-dim 512 \
|
||||
--decoder-dim 512 \
|
||||
--joiner-dim 512 \
|
||||
--simulate-streaming 1 \
|
||||
--decode-chunk-size ${chunk} \
|
||||
--left-context ${left} \
|
||||
--causal-convolution 1 \
|
||||
--epoch 25 \
|
||||
--avg 3 \
|
||||
--exp-dir ./pruned_transducer_stateless5/exp \
|
||||
--max-sym-per-frame 1 \
|
||||
--max-duration 1000 \
|
||||
--decoding-method ${decoding_method}
|
||||
done
|
||||
done
|
||||
```
|
||||
|
||||
Pre-trained models, training and decoding logs, and decoding results are available at <https://huggingface.co/pkufool/icefall_librispeech_streaming_pruned_transducer_stateless5_20220729>
|
||||
|
||||
|
||||
### LibriSpeech BPE training results (Pruned Stateless Conv-Emformer RNN-T)
|
||||
|
||||
@ -1686,6 +2158,118 @@ avg=11
|
||||
You can find the tensorboard log at: <https://tensorboard.dev/experiment/D7NQc3xqTpyVmWi5FnWjrA>
|
||||
|
||||
|
||||
### LibriSpeech BPE training results (Conformer-CTC 2)
|
||||
|
||||
#### [conformer_ctc2](./conformer_ctc2)
|
||||
|
||||
#### 2022-07-21
|
||||
|
||||
It implements a 'reworked' version of CTC attention model.
|
||||
As demenstrated by pruned_transducer_stateless2, reworked Conformer model has superior performance compared to the original Conformer.
|
||||
So in this modified version of CTC attention model, it has the reworked Conformer as the encoder and the reworked Transformer as the decoder.
|
||||
conformer_ctc2 also integrates with the idea of the 'averaging models' in pruned_transducer_stateless4.
|
||||
|
||||
The WERs on comparisons with a baseline model, for the librispeech test dataset, are listed as below.
|
||||
|
||||
The baseline model is the original conformer CTC attention model trained with icefall/egs/librispeech/ASR/conformer_ctc.
|
||||
The model is downloaded from <https://huggingface.co/csukuangfj/icefall-asr-librispeech-conformer-ctc-jit-bpe-500-2021-11-09>.
|
||||
This model has 12 layers of Conformer encoder layers and 6 Transformer decoder layers.
|
||||
Number of model parameters is 109,226,120.
|
||||
It has been trained with 90 epochs with full Librispeech dataset.
|
||||
|
||||
For this reworked CTC attention model, it has 12 layers of reworked Conformer encoder layers and 6 reworked Transformer decoder layers.
|
||||
Number of model parameters is 103,071,035.
|
||||
With full Librispeech data set, it was trained for **only** 30 epochs because the reworked model would converge much faster.
|
||||
Please refer to <https://tensorboard.dev/experiment/GR1s6VrJRTW5rtB50jakew/#scalars> to see the loss convergence curve.
|
||||
Please find the above trained model at <https://huggingface.co/WayneWiser/icefall-asr-librispeech-conformer-ctc2-jit-bpe-500-2022-07-21> in huggingface.
|
||||
|
||||
The decoding configuration for the reworked model is --epoch 30, --avg 8, --use-averaged-model True, which is the best after searching.
|
||||
|
||||
| WER | reworked ctc attention | with --epoch 30 --avg 8 --use-averaged-model True | | ctc attention| with --epoch 77 --avg 55 | |
|
||||
|------------------------|-------|------|------|------|------|-----|
|
||||
| test sets | test-clean | test-other | Avg | test-clean | test-other | Avg |
|
||||
| ctc-greedy-search | 2.98% | 7.14%| 5.06%| 2.90%| 7.47%| 5.19%|
|
||||
| ctc-decoding | 2.98% | 7.14%| 5.06%| 2.90%| 7.47%| 5.19%|
|
||||
| 1best | 2.93% | 6.37%| 4.65%| 2.70%| 6.49%| 4.60%|
|
||||
| nbest | 2.94% | 6.39%| 4.67%| 2.70%| 6.48%| 4.59%|
|
||||
| nbest-rescoring | 2.68% | 5.77%| 4.23%| 2.55%| 6.07%| 4.31%|
|
||||
| whole-lattice-rescoring| 2.66% | 5.76%| 4.21%| 2.56%| 6.04%| 4.30%|
|
||||
| attention-decoder | 2.59% | 5.54%| 4.07%| 2.41%| 5.77%| 4.09%|
|
||||
| nbest-oracle | 1.53% | 3.47%| 2.50%| 1.69%| 4.02%| 2.86%|
|
||||
| rnn-lm | 2.37% | 4.98%| 3.68%| 2.31%| 5.35%| 3.83%|
|
||||
|
||||
|
||||
|
||||
conformer_ctc2 also implements the CTC greedy search decoding, it has the identical WERs with the CTC-decoding method.
|
||||
For other decoding methods, the average WER of the two test sets with the two models is similar.
|
||||
Except for the 1best and nbest methods, the overall performance of reworked model is better than the baseline model.
|
||||
|
||||
|
||||
To reproduce the above result, use the following commands.
|
||||
|
||||
The training commands are:
|
||||
|
||||
```bash
|
||||
WORLD_SIZE=8
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
|
||||
./conformer_ctc2/train.py \
|
||||
--manifest-dir data/fbank \
|
||||
--exp-dir conformer_ctc2/exp \
|
||||
--full-libri 1 \
|
||||
--spec-aug-time-warp-factor 80 \
|
||||
--max-duration 300 \
|
||||
--world-size ${WORLD_SIZE} \
|
||||
--start-epoch 1 \
|
||||
--num-epochs 30 \
|
||||
--att-rate 0.7 \
|
||||
--num-decoder-layers 6
|
||||
```
|
||||
|
||||
|
||||
And the following commands are for decoding:
|
||||
|
||||
```bash
|
||||
|
||||
|
||||
for method in ctc-greedy-search ctc-decoding 1best nbest-oracle; do
|
||||
python3 ./conformer_ctc2/decode.py \
|
||||
--exp-dir conformer_ctc2/exp \
|
||||
--use-averaged-model True --epoch 30 --avg 8 --max-duration 200 --method $method
|
||||
done
|
||||
|
||||
for method in nbest nbest-rescoring whole-lattice-rescoring attention-decoder ; do
|
||||
python3 ./conformer_ctc2/decode.py \
|
||||
--exp-dir conformer_ctc2/exp \
|
||||
--use-averaged-model True --epoch 30 --avg 8 --max-duration 20 --method $method
|
||||
done
|
||||
|
||||
rnn_dir=$(git rev-parse --show-toplevel)/icefall/rnn_lm
|
||||
./conformer_ctc2/decode.py \
|
||||
--exp-dir conformer_ctc2/exp \
|
||||
--lang-dir data/lang_bpe_500 \
|
||||
--lm-dir data/lm \
|
||||
--max-duration 30 \
|
||||
--concatenate-cuts 0 \
|
||||
--bucketing-sampler 1 \
|
||||
--num-paths 1000 \
|
||||
--use-averaged-model True \
|
||||
--epoch 30 \
|
||||
--avg 8 \
|
||||
--nbest-scale 0.5 \
|
||||
--rnn-lm-exp-dir ${rnn_dir}/exp \
|
||||
--rnn-lm-epoch 29 \
|
||||
--rnn-lm-avg 3 \
|
||||
--rnn-lm-embedding-dim 2048 \
|
||||
--rnn-lm-hidden-dim 2048 \
|
||||
--rnn-lm-num-layers 3 \
|
||||
--rnn-lm-tie-weights true \
|
||||
--method rnn-lm
|
||||
```
|
||||
|
||||
You can find the RNN-LM pre-trained model at
|
||||
<https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm/tree/main>
|
||||
|
||||
|
||||
### LibriSpeech BPE training results (Conformer-CTC)
|
||||
|
||||
#### 2021-11-09
|
||||
|
@ -525,6 +525,7 @@ def decode_dataset(
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
@ -544,9 +545,9 @@ def decode_dataset(
|
||||
for lm_scale, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for hyp_words, ref_text in zip(hyps, texts):
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((ref_words, hyp_words))
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[lm_scale].extend(this_batch)
|
||||
else:
|
||||
@ -586,6 +587,7 @@ def save_results(
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
if enable_log:
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
@ -779,6 +781,8 @@ def main():
|
||||
)
|
||||
rnn_lm_model.eval()
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
|
||||
test_clean_cuts = librispeech.test_clean_cuts()
|
||||
|
@ -447,6 +447,17 @@ def compute_loss(
|
||||
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
|
||||
# `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa
|
||||
info["utterances"] = feature.size(0)
|
||||
# averaged input duration in frames over utterances
|
||||
info["utt_duration"] = supervisions["num_frames"].sum().item()
|
||||
# averaged padding proportion over utterances
|
||||
info["utt_pad_proportion"] = (
|
||||
((feature.size(1) - supervisions["num_frames"]) / feature.size(1))
|
||||
.sum()
|
||||
.item()
|
||||
)
|
||||
|
||||
return loss, info
|
||||
|
||||
|
||||
|
1
egs/librispeech/ASR/conformer_ctc2/__init__.py
Symbolic link
1
egs/librispeech/ASR/conformer_ctc2/__init__.py
Symbolic link
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless2/__init__.py
|
1
egs/librispeech/ASR/conformer_ctc2/asr_datamodule.py
Symbolic link
1
egs/librispeech/ASR/conformer_ctc2/asr_datamodule.py
Symbolic link
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless2/asr_datamodule.py
|
252
egs/librispeech/ASR/conformer_ctc2/attention.py
Normal file
252
egs/librispeech/ASR/conformer_ctc2/attention.py
Normal file
@ -0,0 +1,252 @@
|
||||
# Copyright 2022 Xiaomi Corp. (author: Quandong Wang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.nn.init import xavier_normal_
|
||||
|
||||
from scaling import ScaledLinear
|
||||
|
||||
|
||||
class MultiheadAttention(nn.Module):
|
||||
r"""Allows the model to jointly attend to information
|
||||
from different representation subspaces.
|
||||
See `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
|
||||
|
||||
.. math::
|
||||
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
|
||||
|
||||
where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
|
||||
|
||||
Args:
|
||||
embed_dim: Total dimension of the model.
|
||||
num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
|
||||
across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
|
||||
dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
|
||||
bias: If specified, adds bias to input / output projection layers. Default: ``True``.
|
||||
add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
|
||||
add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
|
||||
Default: ``False``.
|
||||
kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
|
||||
vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
|
||||
batch_first: If ``True``, then the input and output tensors are provided
|
||||
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
|
||||
|
||||
Examples::
|
||||
|
||||
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
|
||||
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
|
||||
"""
|
||||
__constants__ = ["batch_first"]
|
||||
bias_k: Optional[torch.Tensor]
|
||||
bias_v: Optional[torch.Tensor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim,
|
||||
num_heads,
|
||||
dropout=0.0,
|
||||
bias=True,
|
||||
add_bias_kv=False,
|
||||
add_zero_attn=False,
|
||||
kdim=None,
|
||||
vdim=None,
|
||||
batch_first=False,
|
||||
device=None,
|
||||
dtype=None,
|
||||
) -> None:
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super(MultiheadAttention, self).__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.kdim = kdim if kdim is not None else embed_dim
|
||||
self.vdim = vdim if vdim is not None else embed_dim
|
||||
self._qkv_same_embed_dim = (
|
||||
self.kdim == embed_dim and self.vdim == embed_dim
|
||||
)
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.dropout = dropout
|
||||
self.batch_first = batch_first
|
||||
self.head_dim = embed_dim // num_heads
|
||||
assert (
|
||||
self.head_dim * num_heads == self.embed_dim
|
||||
), "embed_dim must be divisible by num_heads"
|
||||
|
||||
if self._qkv_same_embed_dim is False:
|
||||
self.q_proj_weight = ScaledLinear(embed_dim, embed_dim, bias=bias)
|
||||
self.k_proj_weight = ScaledLinear(self.kdim, embed_dim, bias=bias)
|
||||
self.v_proj_weight = ScaledLinear(self.vdim, embed_dim, bias=bias)
|
||||
self.register_parameter("in_proj_weight", None)
|
||||
else:
|
||||
self.in_proj_weight = ScaledLinear(
|
||||
embed_dim, 3 * embed_dim, bias=bias
|
||||
)
|
||||
self.register_parameter("q_proj_weight", None)
|
||||
self.register_parameter("k_proj_weight", None)
|
||||
self.register_parameter("v_proj_weight", None)
|
||||
|
||||
if not bias:
|
||||
self.register_parameter("in_proj_bias", None)
|
||||
|
||||
self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=bias)
|
||||
|
||||
if add_bias_kv:
|
||||
self.bias_k = nn.Parameter(
|
||||
torch.empty((1, 1, embed_dim), **factory_kwargs)
|
||||
)
|
||||
self.bias_v = nn.Parameter(
|
||||
torch.empty((1, 1, embed_dim), **factory_kwargs)
|
||||
)
|
||||
else:
|
||||
self.bias_k = self.bias_v = None
|
||||
|
||||
self.add_zero_attn = add_zero_attn
|
||||
|
||||
self._reset_parameters()
|
||||
|
||||
def _reset_parameters(self):
|
||||
if self.bias_k is not None:
|
||||
xavier_normal_(self.bias_k)
|
||||
if self.bias_v is not None:
|
||||
xavier_normal_(self.bias_v)
|
||||
|
||||
def __setstate__(self, state):
|
||||
# Support loading old MultiheadAttention checkpoints generated by v1.1.0
|
||||
if "_qkv_same_embed_dim" not in state:
|
||||
state["_qkv_same_embed_dim"] = True
|
||||
|
||||
super(MultiheadAttention, self).__setstate__(state)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: Tensor,
|
||||
key: Tensor,
|
||||
value: Tensor,
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
need_weights: bool = True,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||
r"""
|
||||
Args:
|
||||
query: Query embeddings of shape :math:`(L, N, E_q)` when ``batch_first=False`` or :math:`(N, L, E_q)`
|
||||
when ``batch_first=True``, where :math:`L` is the target sequence length, :math:`N` is the batch size,
|
||||
and :math:`E_q` is the query embedding dimension ``embed_dim``. Queries are compared against
|
||||
key-value pairs to produce the output. See "Attention Is All You Need" for more details.
|
||||
key: Key embeddings of shape :math:`(S, N, E_k)` when ``batch_first=False`` or :math:`(N, S, E_k)` when
|
||||
``batch_first=True``, where :math:`S` is the source sequence length, :math:`N` is the batch size, and
|
||||
:math:`E_k` is the key embedding dimension ``kdim``. See "Attention Is All You Need" for more details.
|
||||
value: Value embeddings of shape :math:`(S, N, E_v)` when ``batch_first=False`` or :math:`(N, S, E_v)` when
|
||||
``batch_first=True``, where :math:`S` is the source sequence length, :math:`N` is the batch size, and
|
||||
:math:`E_v` is the value embedding dimension ``vdim``. See "Attention Is All You Need" for more details.
|
||||
key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
|
||||
to ignore for the purpose of attention (i.e. treat as "padding"). Binary and byte masks are supported.
|
||||
For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
|
||||
the purpose of attention. For a byte mask, a non-zero value indicates that the corresponding ``key``
|
||||
value will be ignored.
|
||||
need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
|
||||
Default: ``True``.
|
||||
attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
|
||||
:math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
|
||||
:math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
|
||||
broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
|
||||
Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
|
||||
corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
|
||||
corresponding position is not allowed to attend. For a float mask, the mask values will be added to
|
||||
the attention weight.
|
||||
|
||||
Outputs:
|
||||
- **attn_output** - Attention outputs of shape :math:`(L, N, E)` when ``batch_first=False`` or
|
||||
:math:`(N, L, E)` when ``batch_first=True``, where :math:`L` is the target sequence length, :math:`N` is
|
||||
the batch size, and :math:`E` is the embedding dimension ``embed_dim``.
|
||||
- **attn_output_weights** - Attention output weights of shape :math:`(N, L, S)`, where :math:`N` is the batch
|
||||
size, :math:`L` is the target sequence length, and :math:`S` is the source sequence length. Only returned
|
||||
when ``need_weights=True``.
|
||||
"""
|
||||
if self.batch_first:
|
||||
query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
|
||||
|
||||
if not self._qkv_same_embed_dim:
|
||||
q_proj_weight = (
|
||||
self.q_proj_weight.get_weight()
|
||||
if self.q_proj_weight is not None
|
||||
else None
|
||||
)
|
||||
k_proj_weight = (
|
||||
self.k_proj_weight.get_weight()
|
||||
if self.k_proj_weight is not None
|
||||
else None
|
||||
)
|
||||
v_proj_weight = (
|
||||
self.v_proj_weight.get_weight()
|
||||
if self.v_proj_weight is not None
|
||||
else None
|
||||
)
|
||||
(
|
||||
attn_output,
|
||||
attn_output_weights,
|
||||
) = nn.functional.multi_head_attention_forward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
self.embed_dim,
|
||||
self.num_heads,
|
||||
self.in_proj_weight.get_weight(),
|
||||
self.in_proj_weight.get_bias(),
|
||||
self.bias_k,
|
||||
self.bias_v,
|
||||
self.add_zero_attn,
|
||||
self.dropout,
|
||||
self.out_proj.get_weight(),
|
||||
self.out_proj.get_bias(),
|
||||
training=self.training,
|
||||
key_padding_mask=key_padding_mask,
|
||||
need_weights=need_weights,
|
||||
attn_mask=attn_mask,
|
||||
use_separate_proj_weight=True,
|
||||
q_proj_weight=q_proj_weight,
|
||||
k_proj_weight=k_proj_weight,
|
||||
v_proj_weight=v_proj_weight,
|
||||
)
|
||||
else:
|
||||
(
|
||||
attn_output,
|
||||
attn_output_weights,
|
||||
) = nn.functional.multi_head_attention_forward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
self.embed_dim,
|
||||
self.num_heads,
|
||||
self.in_proj_weight.get_weight(),
|
||||
self.in_proj_weight.get_bias(),
|
||||
self.bias_k,
|
||||
self.bias_v,
|
||||
self.add_zero_attn,
|
||||
self.dropout,
|
||||
self.out_proj.get_weight(),
|
||||
self.out_proj.get_bias(),
|
||||
training=self.training,
|
||||
key_padding_mask=key_padding_mask,
|
||||
need_weights=need_weights,
|
||||
attn_mask=attn_mask,
|
||||
)
|
||||
if self.batch_first:
|
||||
return attn_output.transpose(1, 0), attn_output_weights
|
||||
else:
|
||||
return attn_output, attn_output_weights
|
964
egs/librispeech/ASR/conformer_ctc2/conformer.py
Normal file
964
egs/librispeech/ASR/conformer_ctc2/conformer.py
Normal file
@ -0,0 +1,964 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
|
||||
# 2022 Xiaomi Corp. (author: Quandong Wang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import math
|
||||
import warnings
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from scaling import (
|
||||
ActivationBalancer,
|
||||
BasicNorm,
|
||||
DoubleSwish,
|
||||
ScaledConv1d,
|
||||
ScaledLinear,
|
||||
)
|
||||
from torch import Tensor, nn
|
||||
from subsampling import Conv2dSubsampling
|
||||
|
||||
from transformer import Supervisions, Transformer, encoder_padding_mask
|
||||
|
||||
|
||||
class Conformer(Transformer):
|
||||
"""
|
||||
Args:
|
||||
num_features (int): Number of input features
|
||||
num_classes (int): Number of output classes
|
||||
subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
|
||||
d_model (int): attention dimension, also the output dimension
|
||||
nhead (int): number of head
|
||||
dim_feedforward (int): feedforward dimention
|
||||
num_encoder_layers (int): number of encoder layers
|
||||
num_decoder_layers (int): number of decoder layers
|
||||
dropout (float): dropout rate
|
||||
layer_dropout (float): layer-dropout rate.
|
||||
cnn_module_kernel (int): Kernel size of convolution module
|
||||
vgg_frontend (bool): whether to use vgg frontend.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_features: int,
|
||||
num_classes: int,
|
||||
subsampling_factor: int = 4,
|
||||
d_model: int = 256,
|
||||
nhead: int = 4,
|
||||
dim_feedforward: int = 2048,
|
||||
num_encoder_layers: int = 12,
|
||||
num_decoder_layers: int = 6,
|
||||
dropout: float = 0.1,
|
||||
layer_dropout: float = 0.075,
|
||||
cnn_module_kernel: int = 31,
|
||||
) -> None:
|
||||
super(Conformer, self).__init__(
|
||||
num_features=num_features,
|
||||
num_classes=num_classes,
|
||||
subsampling_factor=subsampling_factor,
|
||||
d_model=d_model,
|
||||
nhead=nhead,
|
||||
dim_feedforward=dim_feedforward,
|
||||
num_encoder_layers=num_encoder_layers,
|
||||
num_decoder_layers=num_decoder_layers,
|
||||
dropout=dropout,
|
||||
layer_dropout=layer_dropout,
|
||||
)
|
||||
|
||||
self.num_features = num_features
|
||||
self.subsampling_factor = subsampling_factor
|
||||
if subsampling_factor != 4:
|
||||
raise NotImplementedError("Support only 'subsampling_factor=4'.")
|
||||
|
||||
# self.encoder_embed converts the input of shape (N, T, num_features)
|
||||
# to the shape (N, T//subsampling_factor, d_model).
|
||||
# That is, it does two things simultaneously:
|
||||
# (1) subsampling: T -> T//subsampling_factor
|
||||
# (2) embedding: num_features -> d_model
|
||||
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
|
||||
|
||||
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
|
||||
|
||||
encoder_layer = ConformerEncoderLayer(
|
||||
d_model,
|
||||
nhead,
|
||||
dim_feedforward,
|
||||
dropout,
|
||||
layer_dropout,
|
||||
cnn_module_kernel,
|
||||
)
|
||||
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
|
||||
|
||||
def run_encoder(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
supervisions: Optional[Supervisions] = None,
|
||||
warmup: float = 1.0,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
The input tensor. Its shape is (batch_size, seq_len, feature_dim).
|
||||
supervisions:
|
||||
Supervision in lhotse format.
|
||||
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
|
||||
CAUTION: It contains length information, i.e., start and number of
|
||||
frames, before subsampling
|
||||
It is read directly from the batch, without any sorting. It is used
|
||||
to compute encoder padding mask, which is used as memory key padding
|
||||
mask for the decoder.
|
||||
warmup:
|
||||
A floating point value that gradually increases from 0 throughout
|
||||
training; when it is >= 1.0 we are "fully warmed up". It is used
|
||||
to turn modules on sequentially.
|
||||
Returns:
|
||||
Tensor: Predictor tensor of dimension (input_length, batch_size, d_model).
|
||||
Tensor: Mask tensor of dimension (batch_size, input_length)
|
||||
"""
|
||||
x = self.encoder_embed(x)
|
||||
x, pos_emb = self.encoder_pos(x)
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
mask = encoder_padding_mask(x.size(0), supervisions)
|
||||
if mask is not None:
|
||||
mask = mask.to(x.device)
|
||||
|
||||
# Caution: We assume the subsampling factor is 4!
|
||||
|
||||
x = self.encoder(
|
||||
x, pos_emb, src_key_padding_mask=mask, warmup=warmup
|
||||
) # (T, N, C)
|
||||
|
||||
# x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||
|
||||
# return x, lengths
|
||||
return x, mask
|
||||
|
||||
|
||||
class ConformerEncoderLayer(nn.Module):
|
||||
"""
|
||||
ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks.
|
||||
See: "Conformer: Convolution-augmented Transformer for Speech Recognition"
|
||||
|
||||
Args:
|
||||
d_model: the number of expected features in the input (required).
|
||||
nhead: the number of heads in the multiheadattention models (required).
|
||||
dim_feedforward: the dimension of the feedforward network model (default=2048).
|
||||
dropout: the dropout value (default=0.1).
|
||||
cnn_module_kernel (int): Kernel size of convolution module.
|
||||
|
||||
Examples::
|
||||
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
|
||||
>>> src = torch.rand(10, 32, 512)
|
||||
>>> pos_emb = torch.rand(32, 19, 512)
|
||||
>>> out = encoder_layer(src, pos_emb)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
nhead: int,
|
||||
dim_feedforward: int = 2048,
|
||||
dropout: float = 0.1,
|
||||
layer_dropout: float = 0.075,
|
||||
cnn_module_kernel: int = 31,
|
||||
) -> None:
|
||||
super(ConformerEncoderLayer, self).__init__()
|
||||
|
||||
self.layer_dropout = layer_dropout
|
||||
|
||||
self.d_model = d_model
|
||||
|
||||
self.self_attn = RelPositionMultiheadAttention(
|
||||
d_model, nhead, dropout=0.0
|
||||
)
|
||||
|
||||
self.feed_forward = nn.Sequential(
|
||||
ScaledLinear(d_model, dim_feedforward),
|
||||
ActivationBalancer(channel_dim=-1),
|
||||
DoubleSwish(),
|
||||
nn.Dropout(dropout),
|
||||
ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
|
||||
)
|
||||
|
||||
self.feed_forward_macaron = nn.Sequential(
|
||||
ScaledLinear(d_model, dim_feedforward),
|
||||
ActivationBalancer(channel_dim=-1),
|
||||
DoubleSwish(),
|
||||
nn.Dropout(dropout),
|
||||
ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
|
||||
)
|
||||
|
||||
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
|
||||
|
||||
self.norm_final = BasicNorm(d_model)
|
||||
|
||||
# try to ensure the output is close to zero-mean (or at least, zero-median).
|
||||
self.balancer = ActivationBalancer(
|
||||
channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
|
||||
)
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src: Tensor,
|
||||
pos_emb: Tensor,
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
warmup: float = 1.0,
|
||||
) -> Tensor:
|
||||
"""
|
||||
Pass the input through the encoder layer.
|
||||
|
||||
Args:
|
||||
src: the sequence to the encoder layer (required).
|
||||
pos_emb: Positional embedding tensor (required).
|
||||
src_mask: the mask for the src sequence (optional).
|
||||
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||
warmup: controls selective bypass of of layers; if < 1.0, we will
|
||||
bypass layers more frequently.
|
||||
|
||||
Shape:
|
||||
src: (S, N, E).
|
||||
pos_emb: (N, 2*S-1, E)
|
||||
src_mask: (S, S).
|
||||
src_key_padding_mask: (N, S).
|
||||
S is the source sequence length, N is the batch size, E is the feature number
|
||||
"""
|
||||
src_orig = src
|
||||
|
||||
warmup_scale = min(0.1 + warmup, 1.0)
|
||||
# alpha = 1.0 means fully use this encoder layer, 0.0 would mean
|
||||
# completely bypass it.
|
||||
if self.training:
|
||||
alpha = (
|
||||
warmup_scale
|
||||
if torch.rand(()).item() <= (1.0 - self.layer_dropout)
|
||||
else 0.1
|
||||
)
|
||||
else:
|
||||
alpha = 1.0
|
||||
|
||||
# macaron style feed forward module
|
||||
src = src + self.dropout(self.feed_forward_macaron(src))
|
||||
|
||||
# multi-headed self-attention module
|
||||
src_att = self.self_attn(
|
||||
src,
|
||||
src,
|
||||
src,
|
||||
pos_emb=pos_emb,
|
||||
attn_mask=src_mask,
|
||||
key_padding_mask=src_key_padding_mask,
|
||||
)[0]
|
||||
src = src + self.dropout(src_att)
|
||||
|
||||
# convolution module
|
||||
src = src + self.dropout(self.conv_module(src))
|
||||
|
||||
# feed forward module
|
||||
src = src + self.dropout(self.feed_forward(src))
|
||||
|
||||
src = self.norm_final(self.balancer(src))
|
||||
|
||||
if alpha != 1.0:
|
||||
src = alpha * src + (1 - alpha) * src_orig
|
||||
|
||||
return src
|
||||
|
||||
|
||||
class ConformerEncoder(nn.Module):
|
||||
r"""ConformerEncoder is a stack of N encoder layers
|
||||
|
||||
Args:
|
||||
encoder_layer: an instance of the ConformerEncoderLayer() class (required).
|
||||
num_layers: the number of sub-encoder-layers in the encoder (required).
|
||||
|
||||
Examples::
|
||||
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
|
||||
>>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6)
|
||||
>>> src = torch.rand(10, 32, 512)
|
||||
>>> pos_emb = torch.rand(32, 19, 512)
|
||||
>>> out = conformer_encoder(src, pos_emb)
|
||||
"""
|
||||
|
||||
def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList(
|
||||
[copy.deepcopy(encoder_layer) for i in range(num_layers)]
|
||||
)
|
||||
self.num_layers = num_layers
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src: Tensor,
|
||||
pos_emb: Tensor,
|
||||
mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
warmup: float = 1.0,
|
||||
) -> Tensor:
|
||||
r"""Pass the input through the encoder layers in turn.
|
||||
|
||||
Args:
|
||||
src: the sequence to the encoder (required).
|
||||
pos_emb: Positional embedding tensor (required).
|
||||
mask: the mask for the src sequence (optional).
|
||||
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||
|
||||
Shape:
|
||||
src: (S, N, E).
|
||||
pos_emb: (N, 2*S-1, E)
|
||||
mask: (S, S).
|
||||
src_key_padding_mask: (N, S).
|
||||
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
|
||||
|
||||
"""
|
||||
output = src
|
||||
|
||||
for i, mod in enumerate(self.layers):
|
||||
output = mod(
|
||||
output,
|
||||
pos_emb,
|
||||
src_mask=mask,
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
warmup=warmup,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class RelPositionalEncoding(torch.nn.Module):
|
||||
"""Relative positional encoding module.
|
||||
|
||||
See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
|
||||
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py
|
||||
|
||||
Args:
|
||||
d_model: Embedding dimension.
|
||||
dropout_rate: Dropout rate.
|
||||
max_len: Maximum input length.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, d_model: int, dropout_rate: float, max_len: int = 5000
|
||||
) -> None:
|
||||
"""Construct an PositionalEncoding object."""
|
||||
super(RelPositionalEncoding, self).__init__()
|
||||
self.d_model = d_model
|
||||
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
||||
self.pe = None
|
||||
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
||||
|
||||
def extend_pe(self, x: Tensor) -> None:
|
||||
"""Reset the positional encodings."""
|
||||
if self.pe is not None:
|
||||
# self.pe contains both positive and negative parts
|
||||
# the length of self.pe is 2 * input_len - 1
|
||||
if self.pe.size(1) >= x.size(1) * 2 - 1:
|
||||
# Note: TorchScript doesn't implement operator== for torch.Device
|
||||
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
|
||||
x.device
|
||||
):
|
||||
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
||||
return
|
||||
# Suppose `i` means to the position of query vecotr and `j` means the
|
||||
# position of key vector. We use position relative positions when keys
|
||||
# are to the left (i>j) and negative relative positions otherwise (i<j).
|
||||
pe_positive = torch.zeros(x.size(1), self.d_model)
|
||||
pe_negative = torch.zeros(x.size(1), self.d_model)
|
||||
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
||||
div_term = torch.exp(
|
||||
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
||||
* -(math.log(10000.0) / self.d_model)
|
||||
)
|
||||
pe_positive[:, 0::2] = torch.sin(position * div_term)
|
||||
pe_positive[:, 1::2] = torch.cos(position * div_term)
|
||||
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
|
||||
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
|
||||
|
||||
# Reserve the order of positive indices and concat both positive and
|
||||
# negative indices. This is used to support the shifting trick
|
||||
# as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
|
||||
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
|
||||
pe_negative = pe_negative[1:].unsqueeze(0)
|
||||
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
||||
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
|
||||
"""Add positional encoding.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor (batch, time, `*`).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Encoded tensor (batch, time, `*`).
|
||||
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
|
||||
|
||||
"""
|
||||
self.extend_pe(x)
|
||||
pos_emb = self.pe[
|
||||
:,
|
||||
self.pe.size(1) // 2
|
||||
- x.size(1)
|
||||
+ 1 : self.pe.size(1) // 2 # noqa E203
|
||||
+ x.size(1),
|
||||
]
|
||||
return self.dropout(x), self.dropout(pos_emb)
|
||||
|
||||
|
||||
class RelPositionMultiheadAttention(nn.Module):
|
||||
r"""Multi-Head Attention layer with relative position encoding
|
||||
|
||||
See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
|
||||
|
||||
Args:
|
||||
embed_dim: total dimension of the model.
|
||||
num_heads: parallel attention heads.
|
||||
dropout: a Dropout layer on attn_output_weights. Default: 0.0.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads)
|
||||
>>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
dropout: float = 0.0,
|
||||
) -> None:
|
||||
super(RelPositionMultiheadAttention, self).__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.num_heads = num_heads
|
||||
self.dropout = dropout
|
||||
self.head_dim = embed_dim // num_heads
|
||||
assert (
|
||||
self.head_dim * num_heads == self.embed_dim
|
||||
), "embed_dim must be divisible by num_heads"
|
||||
|
||||
self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True)
|
||||
self.out_proj = ScaledLinear(
|
||||
embed_dim, embed_dim, bias=True, initial_scale=0.25
|
||||
)
|
||||
|
||||
# linear transformation for positional encoding.
|
||||
self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False)
|
||||
# these two learnable bias are used in matrix c and matrix d
|
||||
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
|
||||
self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
|
||||
self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
|
||||
self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach())
|
||||
self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach())
|
||||
self._reset_parameters()
|
||||
|
||||
def _pos_bias_u(self):
|
||||
return self.pos_bias_u * self.pos_bias_u_scale.exp()
|
||||
|
||||
def _pos_bias_v(self):
|
||||
return self.pos_bias_v * self.pos_bias_v_scale.exp()
|
||||
|
||||
def _reset_parameters(self) -> None:
|
||||
nn.init.normal_(self.pos_bias_u, std=0.01)
|
||||
nn.init.normal_(self.pos_bias_v, std=0.01)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: Tensor,
|
||||
key: Tensor,
|
||||
value: Tensor,
|
||||
pos_emb: Tensor,
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
need_weights: bool = True,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||
r"""
|
||||
Args:
|
||||
query, key, value: map a query and a set of key-value pairs to an output.
|
||||
pos_emb: Positional embedding tensor
|
||||
key_padding_mask: if provided, specified padding elements in the key will
|
||||
be ignored by the attention. When given a binary mask and a value is True,
|
||||
the corresponding value on the attention layer will be ignored. When given
|
||||
a byte mask and a value is non-zero, the corresponding value on the attention
|
||||
layer will be ignored
|
||||
need_weights: output attn_output_weights.
|
||||
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
|
||||
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
|
||||
|
||||
Shape:
|
||||
- Inputs:
|
||||
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
|
||||
the embedding dimension.
|
||||
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
|
||||
the embedding dimension.
|
||||
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
|
||||
the embedding dimension.
|
||||
- pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is
|
||||
the embedding dimension.
|
||||
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
|
||||
If a ByteTensor is provided, the non-zero positions will be ignored while the position
|
||||
with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
|
||||
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
|
||||
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
|
||||
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
|
||||
S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
|
||||
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
|
||||
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
|
||||
is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
|
||||
is provided, it will be added to the attention weight.
|
||||
|
||||
- Outputs:
|
||||
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
|
||||
E is the embedding dimension.
|
||||
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
|
||||
L is the target sequence length, S is the source sequence length.
|
||||
"""
|
||||
return self.multi_head_attention_forward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
pos_emb,
|
||||
self.embed_dim,
|
||||
self.num_heads,
|
||||
self.in_proj.get_weight(),
|
||||
self.in_proj.get_bias(),
|
||||
self.dropout,
|
||||
self.out_proj.get_weight(),
|
||||
self.out_proj.get_bias(),
|
||||
training=self.training,
|
||||
key_padding_mask=key_padding_mask,
|
||||
need_weights=need_weights,
|
||||
attn_mask=attn_mask,
|
||||
)
|
||||
|
||||
def rel_shift(self, x: Tensor) -> Tensor:
|
||||
"""Compute relative positional encoding.
|
||||
|
||||
Args:
|
||||
x: Input tensor (batch, head, time1, 2*time1-1).
|
||||
time1 means the length of query vector.
|
||||
|
||||
Returns:
|
||||
Tensor: tensor of shape (batch, head, time1, time2)
|
||||
(note: time2 has the same value as time1, but it is for
|
||||
the key, while time1 is for the query).
|
||||
"""
|
||||
(batch_size, num_heads, time1, n) = x.shape
|
||||
assert n == 2 * time1 - 1
|
||||
# Note: TorchScript requires explicit arg for stride()
|
||||
batch_stride = x.stride(0)
|
||||
head_stride = x.stride(1)
|
||||
time1_stride = x.stride(2)
|
||||
n_stride = x.stride(3)
|
||||
return x.as_strided(
|
||||
(batch_size, num_heads, time1, time1),
|
||||
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
|
||||
storage_offset=n_stride * (time1 - 1),
|
||||
)
|
||||
|
||||
def multi_head_attention_forward(
|
||||
self,
|
||||
query: Tensor,
|
||||
key: Tensor,
|
||||
value: Tensor,
|
||||
pos_emb: Tensor,
|
||||
embed_dim_to_check: int,
|
||||
num_heads: int,
|
||||
in_proj_weight: Tensor,
|
||||
in_proj_bias: Tensor,
|
||||
dropout_p: float,
|
||||
out_proj_weight: Tensor,
|
||||
out_proj_bias: Tensor,
|
||||
training: bool = True,
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
need_weights: bool = True,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||
r"""
|
||||
Args:
|
||||
query, key, value: map a query and a set of key-value pairs to an output.
|
||||
pos_emb: Positional embedding tensor
|
||||
embed_dim_to_check: total dimension of the model.
|
||||
num_heads: parallel attention heads.
|
||||
in_proj_weight, in_proj_bias: input projection weight and bias.
|
||||
dropout_p: probability of an element to be zeroed.
|
||||
out_proj_weight, out_proj_bias: the output projection weight and bias.
|
||||
training: apply dropout if is ``True``.
|
||||
key_padding_mask: if provided, specified padding elements in the key will
|
||||
be ignored by the attention. This is an binary mask. When the value is True,
|
||||
the corresponding value on the attention layer will be filled with -inf.
|
||||
need_weights: output attn_output_weights.
|
||||
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
|
||||
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
|
||||
|
||||
Shape:
|
||||
Inputs:
|
||||
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
|
||||
the embedding dimension.
|
||||
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
|
||||
the embedding dimension.
|
||||
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
|
||||
the embedding dimension.
|
||||
- pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence
|
||||
length, N is the batch size, E is the embedding dimension.
|
||||
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
|
||||
If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
|
||||
will be unchanged. If a BoolTensor is provided, the positions with the
|
||||
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
|
||||
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
|
||||
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
|
||||
S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
|
||||
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
|
||||
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
|
||||
are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
|
||||
is provided, it will be added to the attention weight.
|
||||
|
||||
Outputs:
|
||||
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
|
||||
E is the embedding dimension.
|
||||
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
|
||||
L is the target sequence length, S is the source sequence length.
|
||||
"""
|
||||
|
||||
tgt_len, bsz, embed_dim = query.size()
|
||||
assert embed_dim == embed_dim_to_check
|
||||
assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
|
||||
|
||||
head_dim = embed_dim // num_heads
|
||||
assert (
|
||||
head_dim * num_heads == embed_dim
|
||||
), "embed_dim must be divisible by num_heads"
|
||||
|
||||
scaling = float(head_dim) ** -0.5
|
||||
|
||||
if torch.equal(query, key) and torch.equal(key, value):
|
||||
# self-attention
|
||||
q, k, v = nn.functional.linear(
|
||||
query, in_proj_weight, in_proj_bias
|
||||
).chunk(3, dim=-1)
|
||||
|
||||
elif torch.equal(key, value):
|
||||
# encoder-decoder attention
|
||||
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||
_b = in_proj_bias
|
||||
_start = 0
|
||||
_end = embed_dim
|
||||
_w = in_proj_weight[_start:_end, :]
|
||||
if _b is not None:
|
||||
_b = _b[_start:_end]
|
||||
q = nn.functional.linear(query, _w, _b)
|
||||
|
||||
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||
_b = in_proj_bias
|
||||
_start = embed_dim
|
||||
_end = None
|
||||
_w = in_proj_weight[_start:, :]
|
||||
if _b is not None:
|
||||
_b = _b[_start:]
|
||||
k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1)
|
||||
|
||||
else:
|
||||
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||
_b = in_proj_bias
|
||||
_start = 0
|
||||
_end = embed_dim
|
||||
_w = in_proj_weight[_start:_end, :]
|
||||
if _b is not None:
|
||||
_b = _b[_start:_end]
|
||||
q = nn.functional.linear(query, _w, _b)
|
||||
|
||||
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||
_b = in_proj_bias
|
||||
_start = embed_dim
|
||||
_end = embed_dim * 2
|
||||
_w = in_proj_weight[_start:_end, :]
|
||||
if _b is not None:
|
||||
_b = _b[_start:_end]
|
||||
k = nn.functional.linear(key, _w, _b)
|
||||
|
||||
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||
_b = in_proj_bias
|
||||
_start = embed_dim * 2
|
||||
_end = None
|
||||
_w = in_proj_weight[_start:, :]
|
||||
if _b is not None:
|
||||
_b = _b[_start:]
|
||||
v = nn.functional.linear(value, _w, _b)
|
||||
|
||||
if attn_mask is not None:
|
||||
assert (
|
||||
attn_mask.dtype == torch.float32
|
||||
or attn_mask.dtype == torch.float64
|
||||
or attn_mask.dtype == torch.float16
|
||||
or attn_mask.dtype == torch.uint8
|
||||
or attn_mask.dtype == torch.bool
|
||||
), "Only float, byte, and bool types are supported for attn_mask, not {}".format(
|
||||
attn_mask.dtype
|
||||
)
|
||||
if attn_mask.dtype == torch.uint8:
|
||||
warnings.warn(
|
||||
"Byte tensor for attn_mask is deprecated. Use bool tensor instead."
|
||||
)
|
||||
attn_mask = attn_mask.to(torch.bool)
|
||||
|
||||
if attn_mask.dim() == 2:
|
||||
attn_mask = attn_mask.unsqueeze(0)
|
||||
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
|
||||
raise RuntimeError(
|
||||
"The size of the 2D attn_mask is not correct."
|
||||
)
|
||||
elif attn_mask.dim() == 3:
|
||||
if list(attn_mask.size()) != [
|
||||
bsz * num_heads,
|
||||
query.size(0),
|
||||
key.size(0),
|
||||
]:
|
||||
raise RuntimeError(
|
||||
"The size of the 3D attn_mask is not correct."
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"attn_mask's dimension {} is not supported".format(
|
||||
attn_mask.dim()
|
||||
)
|
||||
)
|
||||
# attn_mask's dim is 3 now.
|
||||
|
||||
# convert ByteTensor key_padding_mask to bool
|
||||
if (
|
||||
key_padding_mask is not None
|
||||
and key_padding_mask.dtype == torch.uint8
|
||||
):
|
||||
warnings.warn(
|
||||
"Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
|
||||
)
|
||||
key_padding_mask = key_padding_mask.to(torch.bool)
|
||||
|
||||
q = (q * scaling).contiguous().view(tgt_len, bsz, num_heads, head_dim)
|
||||
k = k.contiguous().view(-1, bsz, num_heads, head_dim)
|
||||
v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
|
||||
|
||||
src_len = k.size(0)
|
||||
|
||||
if key_padding_mask is not None:
|
||||
assert key_padding_mask.size(0) == bsz, "{} == {}".format(
|
||||
key_padding_mask.size(0), bsz
|
||||
)
|
||||
assert key_padding_mask.size(1) == src_len, "{} == {}".format(
|
||||
key_padding_mask.size(1), src_len
|
||||
)
|
||||
|
||||
q = q.transpose(0, 1) # (batch, time1, head, d_k)
|
||||
|
||||
pos_emb_bsz = pos_emb.size(0)
|
||||
assert pos_emb_bsz in (1, bsz) # actually it is 1
|
||||
p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
|
||||
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
|
||||
|
||||
q_with_bias_u = (q + self._pos_bias_u()).transpose(
|
||||
1, 2
|
||||
) # (batch, head, time1, d_k)
|
||||
|
||||
q_with_bias_v = (q + self._pos_bias_v()).transpose(
|
||||
1, 2
|
||||
) # (batch, head, time1, d_k)
|
||||
|
||||
# compute attention score
|
||||
# first compute matrix a and matrix c
|
||||
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
|
||||
k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
|
||||
matrix_ac = torch.matmul(
|
||||
q_with_bias_u, k
|
||||
) # (batch, head, time1, time2)
|
||||
|
||||
# compute matrix b and matrix d
|
||||
matrix_bd = torch.matmul(
|
||||
q_with_bias_v, p.transpose(-2, -1)
|
||||
) # (batch, head, time1, 2*time1-1)
|
||||
matrix_bd = self.rel_shift(matrix_bd)
|
||||
|
||||
attn_output_weights = (
|
||||
matrix_ac + matrix_bd
|
||||
) # (batch, head, time1, time2)
|
||||
|
||||
attn_output_weights = attn_output_weights.view(
|
||||
bsz * num_heads, tgt_len, -1
|
||||
)
|
||||
|
||||
assert list(attn_output_weights.size()) == [
|
||||
bsz * num_heads,
|
||||
tgt_len,
|
||||
src_len,
|
||||
]
|
||||
|
||||
if attn_mask is not None:
|
||||
if attn_mask.dtype == torch.bool:
|
||||
attn_output_weights.masked_fill_(attn_mask, float("-inf"))
|
||||
else:
|
||||
attn_output_weights += attn_mask
|
||||
|
||||
if key_padding_mask is not None:
|
||||
attn_output_weights = attn_output_weights.view(
|
||||
bsz, num_heads, tgt_len, src_len
|
||||
)
|
||||
attn_output_weights = attn_output_weights.masked_fill(
|
||||
key_padding_mask.unsqueeze(1).unsqueeze(2),
|
||||
float("-inf"),
|
||||
)
|
||||
attn_output_weights = attn_output_weights.view(
|
||||
bsz * num_heads, tgt_len, src_len
|
||||
)
|
||||
|
||||
attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
|
||||
attn_output_weights = nn.functional.dropout(
|
||||
attn_output_weights, p=dropout_p, training=training
|
||||
)
|
||||
|
||||
attn_output = torch.bmm(attn_output_weights, v)
|
||||
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
|
||||
attn_output = (
|
||||
attn_output.transpose(0, 1)
|
||||
.contiguous()
|
||||
.view(tgt_len, bsz, embed_dim)
|
||||
)
|
||||
attn_output = nn.functional.linear(
|
||||
attn_output, out_proj_weight, out_proj_bias
|
||||
)
|
||||
|
||||
if need_weights:
|
||||
# average attention weights over heads
|
||||
attn_output_weights = attn_output_weights.view(
|
||||
bsz, num_heads, tgt_len, src_len
|
||||
)
|
||||
return attn_output, attn_output_weights.sum(dim=1) / num_heads
|
||||
else:
|
||||
return attn_output, None
|
||||
|
||||
|
||||
class ConvolutionModule(nn.Module):
|
||||
"""ConvolutionModule in Conformer model.
|
||||
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
|
||||
|
||||
Args:
|
||||
channels (int): The number of channels of conv layers.
|
||||
kernel_size (int): Kernerl size of conv layers.
|
||||
bias (bool): Whether to use bias in conv layers (default=True).
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, channels: int, kernel_size: int, bias: bool = True
|
||||
) -> None:
|
||||
"""Construct an ConvolutionModule object."""
|
||||
super(ConvolutionModule, self).__init__()
|
||||
# kernerl_size should be a odd number for 'SAME' padding
|
||||
assert (kernel_size - 1) % 2 == 0
|
||||
|
||||
self.pointwise_conv1 = ScaledConv1d(
|
||||
channels,
|
||||
2 * channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
# after pointwise_conv1 we put x through a gated linear unit (nn.functional.glu).
|
||||
# For most layers the normal rms value of channels of x seems to be in the range 1 to 4,
|
||||
# but sometimes, for some reason, for layer 0 the rms ends up being very large,
|
||||
# between 50 and 100 for different channels. This will cause very peaky and
|
||||
# sparse derivatives for the sigmoid gating function, which will tend to make
|
||||
# the loss function not learn effectively. (for most layers the average absolute values
|
||||
# are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion,
|
||||
# at the output of pointwise_conv1.output is around 0.35 to 0.45 for different
|
||||
# layers, which likely breaks down as 0.5 for the "linear" half and
|
||||
# 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we
|
||||
# constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
|
||||
# it will be in a better position to start learning something, i.e. to latch onto
|
||||
# the correct range.
|
||||
self.deriv_balancer1 = ActivationBalancer(
|
||||
channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
|
||||
)
|
||||
|
||||
self.depthwise_conv = ScaledConv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
groups=channels,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
self.deriv_balancer2 = ActivationBalancer(
|
||||
channel_dim=1, min_positive=0.05, max_positive=1.0
|
||||
)
|
||||
|
||||
self.activation = DoubleSwish()
|
||||
|
||||
self.pointwise_conv2 = ScaledConv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=bias,
|
||||
initial_scale=0.25,
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
"""Compute convolution module.
|
||||
|
||||
Args:
|
||||
x: Input tensor (#time, batch, channels).
|
||||
|
||||
Returns:
|
||||
Tensor: Output tensor (#time, batch, channels).
|
||||
|
||||
"""
|
||||
# exchange the temporal dimension and the feature dimension
|
||||
x = x.permute(1, 2, 0) # (#batch, channels, time).
|
||||
|
||||
# GLU mechanism
|
||||
x = self.pointwise_conv1(x) # (batch, 2*channels, time)
|
||||
|
||||
x = self.deriv_balancer1(x)
|
||||
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
|
||||
|
||||
# 1D Depthwise Conv
|
||||
x = self.depthwise_conv(x)
|
||||
|
||||
x = self.deriv_balancer2(x)
|
||||
x = self.activation(x)
|
||||
|
||||
x = self.pointwise_conv2(x) # (batch, channel, time)
|
||||
|
||||
return x.permute(2, 0, 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
feature_dim = 50
|
||||
c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
|
||||
batch_size = 5
|
||||
seq_len = 20
|
||||
# Just make sure the forward pass runs.
|
||||
f = c(
|
||||
torch.randn(batch_size, seq_len, feature_dim),
|
||||
torch.full((batch_size,), seq_len, dtype=torch.int64),
|
||||
warmup=0.5,
|
||||
)
|
999
egs/librispeech/ASR/conformer_ctc2/decode.py
Executable file
999
egs/librispeech/ASR/conformer_ctc2/decode.py
Executable file
@ -0,0 +1,999 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo,
|
||||
# Fangjun Kuang,
|
||||
# Quandong Wang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from conformer import Conformer
|
||||
|
||||
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.decode import (
|
||||
get_lattice,
|
||||
nbest_decoding,
|
||||
nbest_oracle,
|
||||
one_best_decoding,
|
||||
rescore_with_attention_decoder,
|
||||
rescore_with_n_best_list,
|
||||
rescore_with_rnn_lm,
|
||||
rescore_with_whole_lattice,
|
||||
)
|
||||
from icefall.env import get_env_info
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.rnn_lm.model import RnnLmModel
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
get_texts,
|
||||
load_averaged_model,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=77,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=15,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
type=str,
|
||||
default="attention-decoder",
|
||||
help="""Decoding method.
|
||||
Supported values are:
|
||||
- (0) ctc-decoding. Use CTC decoding. It uses a sentence piece
|
||||
model, i.e., lang_dir/bpe.model, to convert word pieces to words.
|
||||
It needs neither a lexicon nor an n-gram LM.
|
||||
- (1) ctc-greedy-search. It only use CTC output and a sentence piece
|
||||
model for decoding. It produces the same results with ctc-decoding.
|
||||
- (2) 1best. Extract the best path from the decoding lattice as the
|
||||
decoding result.
|
||||
- (3) nbest. Extract n paths from the decoding lattice; the path
|
||||
with the highest score is the decoding result.
|
||||
- (4) nbest-rescoring. Extract n paths from the decoding lattice,
|
||||
rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
|
||||
the highest score is the decoding result.
|
||||
- (5) whole-lattice-rescoring. Rescore the decoding lattice with an
|
||||
n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
|
||||
is the decoding result.
|
||||
- (6) attention-decoder. Extract n paths from the LM rescored
|
||||
lattice, the path with the highest score is the decoding result.
|
||||
- (7) rnn-lm. Rescoring with attention-decoder and RNN LM. We assume
|
||||
you have trained an RNN LM using ./rnn_lm/train.py
|
||||
- (8) nbest-oracle. Its WER is the lower bound of any n-best
|
||||
rescoring method can achieve. Useful for debugging n-best
|
||||
rescoring method.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-decoder-layers",
|
||||
type=int,
|
||||
default=6,
|
||||
help="""Number of decoder layer of transformer decoder.
|
||||
Setting this to 0 will not create the decoder at all (pure CTC model)
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-paths",
|
||||
type=int,
|
||||
default=100,
|
||||
help="""Number of paths for n-best based decoding method.
|
||||
Used only when "method" is one of the following values:
|
||||
nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nbest-scale",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="""The scale to be applied to `lattice.scores`.
|
||||
It's needed if you use any kinds of n-best based rescoring.
|
||||
Used only when "method" is one of the following values:
|
||||
nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle
|
||||
A smaller value results in more unique paths.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="conformer_ctc2/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
default="data/lang_bpe_500",
|
||||
help="The lang dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lm-dir",
|
||||
type=str,
|
||||
default="data/lm",
|
||||
help="""The n-gram LM dir.
|
||||
It should contain either G_4_gram.pt or G_4_gram.fst.txt
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rnn-lm-exp-dir",
|
||||
type=str,
|
||||
default="rnn_lm/exp",
|
||||
help="""Used only when --method is rnn-lm.
|
||||
It specifies the path to RNN LM exp dir.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rnn-lm-epoch",
|
||||
type=int,
|
||||
default=7,
|
||||
help="""Used only when --method is rnn-lm.
|
||||
It specifies the checkpoint to use.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rnn-lm-avg",
|
||||
type=int,
|
||||
default=2,
|
||||
help="""Used only when --method is rnn-lm.
|
||||
It specifies the number of checkpoints to average.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rnn-lm-embedding-dim",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="Embedding dim of the model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rnn-lm-hidden-dim",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="Hidden dim of the model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rnn-lm-num-layers",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Number of RNN layers the model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rnn-lm-tie-weights",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""True to share the weights between the input embedding layer and the
|
||||
last output linear layer
|
||||
""",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def get_params() -> AttributeDict:
|
||||
params = AttributeDict(
|
||||
{
|
||||
# parameters for conformer
|
||||
"subsampling_factor": 4,
|
||||
"feature_dim": 80,
|
||||
"nhead": 8,
|
||||
"dim_feedforward": 2048,
|
||||
"encoder_dim": 512,
|
||||
"num_encoder_layers": 12,
|
||||
# parameters for decoding
|
||||
"search_beam": 20,
|
||||
"output_beam": 8,
|
||||
"min_active_states": 30,
|
||||
"max_active_states": 10000,
|
||||
"use_double_scores": True,
|
||||
"env_info": get_env_info(),
|
||||
}
|
||||
)
|
||||
return params
|
||||
|
||||
|
||||
def ctc_greedy_search(
|
||||
nnet_output: torch.Tensor,
|
||||
memory: torch.Tensor,
|
||||
memory_key_padding_mask: torch.Tensor,
|
||||
) -> List[List[int]]:
|
||||
"""Apply CTC greedy search
|
||||
|
||||
Args:
|
||||
speech (torch.Tensor): (batch, max_len, feat_dim)
|
||||
speech_length (torch.Tensor): (batch, )
|
||||
Returns:
|
||||
List[List[int]]: best path result
|
||||
"""
|
||||
batch_size = memory.shape[1]
|
||||
# Let's assume B = batch_size
|
||||
encoder_out = memory
|
||||
encoder_mask = memory_key_padding_mask
|
||||
maxlen = encoder_out.size(0)
|
||||
|
||||
ctc_probs = nnet_output # (B, maxlen, vocab_size)
|
||||
topk_prob, topk_index = ctc_probs.topk(1, dim=2) # (B, maxlen, 1)
|
||||
topk_index = topk_index.view(batch_size, maxlen) # (B, maxlen)
|
||||
topk_index = topk_index.masked_fill_(encoder_mask, 0) # (B, maxlen)
|
||||
hyps = [hyp.tolist() for hyp in topk_index]
|
||||
scores = topk_prob.max(1)
|
||||
hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps]
|
||||
return hyps, scores
|
||||
|
||||
|
||||
def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
|
||||
# from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py
|
||||
new_hyp: List[int] = []
|
||||
cur = 0
|
||||
while cur < len(hyp):
|
||||
if hyp[cur] != 0:
|
||||
new_hyp.append(hyp[cur])
|
||||
prev = cur
|
||||
while cur < len(hyp) and hyp[cur] == hyp[prev]:
|
||||
cur += 1
|
||||
return new_hyp
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
rnn_lm_model: Optional[nn.Module],
|
||||
HLG: Optional[k2.Fsa],
|
||||
H: Optional[k2.Fsa],
|
||||
bpe_model: Optional[spm.SentencePieceProcessor],
|
||||
batch: dict,
|
||||
word_table: k2.SymbolTable,
|
||||
sos_id: int,
|
||||
eos_id: int,
|
||||
G: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
|
||||
- key: It indicates the setting used for decoding. For example,
|
||||
if no rescoring is used, the key is the string `no_rescore`.
|
||||
If LM rescoring is used, the key is the string `lm_scale_xxx`,
|
||||
where `xxx` is the value of `lm_scale`. An example key is
|
||||
`lm_scale_0.7`
|
||||
- value: It contains the decoding result. `len(value)` equals to
|
||||
batch size. `value[i]` is the decoding result for the i-th
|
||||
utterance in the given batch.
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
|
||||
- params.method is "1best", it uses 1best decoding without LM rescoring.
|
||||
- params.method is "nbest", it uses nbest decoding without LM rescoring.
|
||||
- params.method is "nbest-rescoring", it uses nbest LM rescoring.
|
||||
- params.method is "whole-lattice-rescoring", it uses whole lattice LM
|
||||
rescoring.
|
||||
|
||||
model:
|
||||
The neural model.
|
||||
rnn_lm_model:
|
||||
The neural model for RNN LM.
|
||||
HLG:
|
||||
The decoding graph. Used only when params.method is NOT ctc-decoding.
|
||||
H:
|
||||
The ctc topo. Used only when params.method is ctc-decoding.
|
||||
bpe_model:
|
||||
The BPE model. Used only when params.method is ctc-decoding.
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
sos_id:
|
||||
The token ID of the SOS.
|
||||
eos_id:
|
||||
The token ID of the EOS.
|
||||
G:
|
||||
An LM. It is not None when params.method is "nbest-rescoring"
|
||||
or "whole-lattice-rescoring". In general, the G in HLG
|
||||
is a 3-gram LM, while this G is a 4-gram LM.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict. Note: If it decodes to nothing, then return None.
|
||||
"""
|
||||
if HLG is not None:
|
||||
device = HLG.device
|
||||
else:
|
||||
device = H.device
|
||||
feature = batch["inputs"]
|
||||
assert feature.ndim == 3
|
||||
feature = feature.to(device)
|
||||
# at entry, feature is (N, T, C)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
|
||||
nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
|
||||
# nnet_output is (N, T, C)
|
||||
|
||||
supervision_segments = torch.stack(
|
||||
(
|
||||
supervisions["sequence_idx"],
|
||||
torch.div(
|
||||
supervisions["start_frame"],
|
||||
params.subsampling_factor,
|
||||
rounding_mode="trunc",
|
||||
),
|
||||
torch.div(
|
||||
supervisions["num_frames"],
|
||||
params.subsampling_factor,
|
||||
rounding_mode="trunc",
|
||||
),
|
||||
),
|
||||
1,
|
||||
).to(torch.int32)
|
||||
|
||||
if H is None:
|
||||
assert HLG is not None
|
||||
decoding_graph = HLG
|
||||
else:
|
||||
assert HLG is None
|
||||
assert bpe_model is not None
|
||||
decoding_graph = H
|
||||
|
||||
lattice = get_lattice(
|
||||
nnet_output=nnet_output,
|
||||
decoding_graph=decoding_graph,
|
||||
supervision_segments=supervision_segments,
|
||||
search_beam=params.search_beam,
|
||||
output_beam=params.output_beam,
|
||||
min_active_states=params.min_active_states,
|
||||
max_active_states=params.max_active_states,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
)
|
||||
|
||||
if params.method == "ctc-decoding":
|
||||
best_path = one_best_decoding(
|
||||
lattice=lattice, use_double_scores=params.use_double_scores
|
||||
)
|
||||
# Note: `best_path.aux_labels` contains token IDs, not word IDs
|
||||
# since we are using H, not HLG here.
|
||||
#
|
||||
# token_ids is a lit-of-list of IDs
|
||||
token_ids = get_texts(best_path)
|
||||
|
||||
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
|
||||
hyps = bpe_model.decode(token_ids)
|
||||
|
||||
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
|
||||
hyps = [s.split() for s in hyps]
|
||||
key = "ctc-decoding"
|
||||
return {key: hyps}
|
||||
|
||||
if params.method == "ctc-greedy-search":
|
||||
hyps, _ = ctc_greedy_search(
|
||||
nnet_output,
|
||||
memory,
|
||||
memory_key_padding_mask,
|
||||
)
|
||||
|
||||
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
|
||||
hyps = bpe_model.decode(hyps)
|
||||
|
||||
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
|
||||
hyps = [s.split() for s in hyps]
|
||||
key = "ctc-greedy-search"
|
||||
return {key: hyps}
|
||||
|
||||
if params.method == "nbest-oracle":
|
||||
# Note: You can also pass rescored lattices to it.
|
||||
# We choose the HLG decoded lattice for speed reasons
|
||||
# as HLG decoding is faster and the oracle WER
|
||||
# is only slightly worse than that of rescored lattices.
|
||||
best_path = nbest_oracle(
|
||||
lattice=lattice,
|
||||
num_paths=params.num_paths,
|
||||
ref_texts=supervisions["text"],
|
||||
word_table=word_table,
|
||||
nbest_scale=params.nbest_scale,
|
||||
oov="<UNK>",
|
||||
)
|
||||
hyps = get_texts(best_path)
|
||||
hyps = [[word_table[i] for i in ids] for ids in hyps]
|
||||
key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa
|
||||
return {key: hyps}
|
||||
|
||||
if params.method in ["1best", "nbest"]:
|
||||
if params.method == "1best":
|
||||
best_path = one_best_decoding(
|
||||
lattice=lattice, use_double_scores=params.use_double_scores
|
||||
)
|
||||
key = "no_rescore"
|
||||
else:
|
||||
best_path = nbest_decoding(
|
||||
lattice=lattice,
|
||||
num_paths=params.num_paths,
|
||||
use_double_scores=params.use_double_scores,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa
|
||||
|
||||
hyps = get_texts(best_path)
|
||||
hyps = [[word_table[i] for i in ids] for ids in hyps]
|
||||
return {key: hyps}
|
||||
|
||||
assert params.method in [
|
||||
"nbest-rescoring",
|
||||
"whole-lattice-rescoring",
|
||||
"attention-decoder",
|
||||
"rnn-lm",
|
||||
]
|
||||
|
||||
lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
|
||||
lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
|
||||
lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
|
||||
|
||||
if params.method == "nbest-rescoring":
|
||||
best_path_dict = rescore_with_n_best_list(
|
||||
lattice=lattice,
|
||||
G=G,
|
||||
num_paths=params.num_paths,
|
||||
lm_scale_list=lm_scale_list,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
elif params.method == "whole-lattice-rescoring":
|
||||
best_path_dict = rescore_with_whole_lattice(
|
||||
lattice=lattice,
|
||||
G_with_epsilon_loops=G,
|
||||
lm_scale_list=lm_scale_list,
|
||||
)
|
||||
elif params.method == "attention-decoder":
|
||||
# lattice uses a 3-gram Lm. We rescore it with a 4-gram LM.
|
||||
rescored_lattice = rescore_with_whole_lattice(
|
||||
lattice=lattice,
|
||||
G_with_epsilon_loops=G,
|
||||
lm_scale_list=None,
|
||||
)
|
||||
# TODO: pass `lattice` instead of `rescored_lattice` to
|
||||
# `rescore_with_attention_decoder`
|
||||
|
||||
best_path_dict = rescore_with_attention_decoder(
|
||||
lattice=rescored_lattice,
|
||||
num_paths=params.num_paths,
|
||||
model=model,
|
||||
memory=memory,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
sos_id=sos_id,
|
||||
eos_id=eos_id,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
elif params.method == "rnn-lm":
|
||||
# lattice uses a 3-gram Lm. We rescore it with a 4-gram LM.
|
||||
rescored_lattice = rescore_with_whole_lattice(
|
||||
lattice=lattice,
|
||||
G_with_epsilon_loops=G,
|
||||
lm_scale_list=None,
|
||||
)
|
||||
|
||||
best_path_dict = rescore_with_rnn_lm(
|
||||
lattice=rescored_lattice,
|
||||
num_paths=params.num_paths,
|
||||
rnn_lm_model=rnn_lm_model,
|
||||
model=model,
|
||||
memory=memory,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
sos_id=sos_id,
|
||||
eos_id=eos_id,
|
||||
blank_id=0,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
else:
|
||||
assert False, f"Unsupported decoding method: {params.method}"
|
||||
|
||||
ans = dict()
|
||||
if best_path_dict is not None:
|
||||
for lm_scale_str, best_path in best_path_dict.items():
|
||||
hyps = get_texts(best_path)
|
||||
hyps = [[word_table[i] for i in ids] for ids in hyps]
|
||||
ans[lm_scale_str] = hyps
|
||||
else:
|
||||
ans = None
|
||||
return ans
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
rnn_lm_model: Optional[nn.Module],
|
||||
HLG: Optional[k2.Fsa],
|
||||
H: Optional[k2.Fsa],
|
||||
bpe_model: Optional[spm.SentencePieceProcessor],
|
||||
word_table: k2.SymbolTable,
|
||||
sos_id: int,
|
||||
eos_id: int,
|
||||
G: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
dl:
|
||||
PyTorch's dataloader containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
rnn_lm_model:
|
||||
The neural model for RNN LM.
|
||||
HLG:
|
||||
The decoding graph. Used only when params.method is NOT ctc-decoding.
|
||||
H:
|
||||
The ctc topo. Used only when params.method is ctc-decoding.
|
||||
bpe_model:
|
||||
The BPE model. Used only when params.method is ctc-decoding.
|
||||
word_table:
|
||||
It is the word symbol table.
|
||||
sos_id:
|
||||
The token ID for SOS.
|
||||
eos_id:
|
||||
The token ID for EOS.
|
||||
G:
|
||||
An LM. It is not None when params.method is "nbest-rescoring"
|
||||
or "whole-lattice-rescoring". In general, the G in HLG
|
||||
is a 3-gram LM, while this G is a 4-gram LM.
|
||||
Returns:
|
||||
Return a dict, whose key may be "no-rescore" if no LM rescoring
|
||||
is used, or it may be "lm_scale_0.7" if LM rescoring is used.
|
||||
Its value is a list of tuples. Each tuple contains two elements:
|
||||
The first is the reference transcript, and the second is the
|
||||
predicted result.
|
||||
"""
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
rnn_lm_model=rnn_lm_model,
|
||||
HLG=HLG,
|
||||
H=H,
|
||||
bpe_model=bpe_model,
|
||||
batch=batch,
|
||||
word_table=word_table,
|
||||
G=G,
|
||||
sos_id=sos_id,
|
||||
eos_id=eos_id,
|
||||
)
|
||||
|
||||
if hyps_dict is not None:
|
||||
for lm_scale, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[lm_scale].extend(this_batch)
|
||||
else:
|
||||
assert (
|
||||
len(results) > 0
|
||||
), "It should not decode to empty in the first batch!"
|
||||
this_batch = []
|
||||
hyp_words = []
|
||||
for ref_text in texts:
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((ref_words, hyp_words))
|
||||
|
||||
for lm_scale in results.keys():
|
||||
results[lm_scale].extend(this_batch)
|
||||
|
||||
num_cuts += len(texts)
|
||||
|
||||
if batch_idx % 100 == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(
|
||||
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
|
||||
):
|
||||
if params.method in ("attention-decoder", "rnn-lm"):
|
||||
# Set it to False since there are too many logs.
|
||||
enable_log = False
|
||||
else:
|
||||
enable_log = True
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
if enable_log:
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results, enable_log=enable_log
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
if enable_log:
|
||||
logging.info(
|
||||
"Wrote detailed error stats to {}".format(errs_filename)
|
||||
)
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LibriSpeechAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
args.lang_dir = Path(args.lang_dir)
|
||||
args.lm_dir = Path(args.lm_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
|
||||
logging.info("Decoding started")
|
||||
logging.info(params)
|
||||
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
max_token_id = max(lexicon.tokens)
|
||||
num_classes = max_token_id + 1 # +1 for the blank
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
graph_compiler = BpeCtcTrainingGraphCompiler(
|
||||
params.lang_dir,
|
||||
device=device,
|
||||
sos_token="<sos/eos>",
|
||||
eos_token="<sos/eos>",
|
||||
)
|
||||
sos_id = graph_compiler.sos_id
|
||||
eos_id = graph_compiler.eos_id
|
||||
|
||||
params.num_classes = num_classes
|
||||
params.sos_id = sos_id
|
||||
params.eos_id = eos_id
|
||||
|
||||
if params.method == "ctc-decoding" or params.method == "ctc-greedy-search":
|
||||
HLG = None
|
||||
H = k2.ctc_topo(
|
||||
max_token=max_token_id,
|
||||
modified=False,
|
||||
device=device,
|
||||
)
|
||||
bpe_model = spm.SentencePieceProcessor()
|
||||
bpe_model.load(str(params.lang_dir / "bpe.model"))
|
||||
else:
|
||||
H = None
|
||||
bpe_model = None
|
||||
HLG = k2.Fsa.from_dict(
|
||||
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
|
||||
)
|
||||
assert HLG.requires_grad is False
|
||||
|
||||
if not hasattr(HLG, "lm_scores"):
|
||||
HLG.lm_scores = HLG.scores.clone()
|
||||
|
||||
if params.method in (
|
||||
"nbest-rescoring",
|
||||
"whole-lattice-rescoring",
|
||||
"attention-decoder",
|
||||
"rnn-lm",
|
||||
):
|
||||
if not (params.lm_dir / "G_4_gram.pt").is_file():
|
||||
logging.info("Loading G_4_gram.fst.txt")
|
||||
logging.warning("It may take 8 minutes.")
|
||||
with open(params.lm_dir / "G_4_gram.fst.txt") as f:
|
||||
first_word_disambig_id = lexicon.word_table["#0"]
|
||||
|
||||
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
|
||||
# G.aux_labels is not needed in later computations, so
|
||||
# remove it here.
|
||||
del G.aux_labels
|
||||
# CAUTION: The following line is crucial.
|
||||
# Arcs entering the back-off state have label equal to #0.
|
||||
# We have to change it to 0 here.
|
||||
G.labels[G.labels >= first_word_disambig_id] = 0
|
||||
# See https://github.com/k2-fsa/k2/issues/874
|
||||
# for why we need to set G.properties to None
|
||||
G.__dict__["_properties"] = None
|
||||
G = k2.Fsa.from_fsas([G]).to(device)
|
||||
G = k2.arc_sort(G)
|
||||
# Save a dummy value so that it can be loaded in C++.
|
||||
# See https://github.com/pytorch/pytorch/issues/67902
|
||||
# for why we need to do this.
|
||||
G.dummy = 1
|
||||
|
||||
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
|
||||
else:
|
||||
logging.info("Loading pre-compiled G_4_gram.pt")
|
||||
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
|
||||
G = k2.Fsa.from_dict(d)
|
||||
|
||||
if params.method in [
|
||||
"whole-lattice-rescoring",
|
||||
"attention-decoder",
|
||||
"rnn-lm",
|
||||
]:
|
||||
# Add epsilon self-loops to G as we will compose
|
||||
# it with the whole lattice later
|
||||
G = k2.add_epsilon_self_loops(G)
|
||||
G = k2.arc_sort(G)
|
||||
G = G.to(device)
|
||||
|
||||
# G.lm_scores is used to replace HLG.lm_scores during
|
||||
# LM rescoring.
|
||||
G.lm_scores = G.scores.clone()
|
||||
else:
|
||||
G = None
|
||||
|
||||
model = Conformer(
|
||||
num_features=params.feature_dim,
|
||||
nhead=params.nhead,
|
||||
d_model=params.encoder_dim,
|
||||
num_classes=num_classes,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
num_encoder_layers=params.num_encoder_layers,
|
||||
num_decoder_layers=params.num_decoder_layers,
|
||||
)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg + 1]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
rnn_lm_model = None
|
||||
if params.method == "rnn-lm":
|
||||
rnn_lm_model = RnnLmModel(
|
||||
vocab_size=params.num_classes,
|
||||
embedding_dim=params.rnn_lm_embedding_dim,
|
||||
hidden_dim=params.rnn_lm_hidden_dim,
|
||||
num_layers=params.rnn_lm_num_layers,
|
||||
tie_weights=params.rnn_lm_tie_weights,
|
||||
)
|
||||
if params.rnn_lm_avg == 1:
|
||||
load_checkpoint(
|
||||
f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt",
|
||||
rnn_lm_model,
|
||||
)
|
||||
rnn_lm_model.to(device)
|
||||
else:
|
||||
rnn_lm_model = load_averaged_model(
|
||||
params.rnn_lm_exp_dir,
|
||||
rnn_lm_model,
|
||||
params.rnn_lm_epoch,
|
||||
params.rnn_lm_avg,
|
||||
device,
|
||||
)
|
||||
rnn_lm_model.eval()
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
|
||||
test_clean_cuts = librispeech.test_clean_cuts()
|
||||
test_other_cuts = librispeech.test_other_cuts()
|
||||
|
||||
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
|
||||
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
|
||||
|
||||
test_sets = ["test-clean", "test-other"]
|
||||
test_dl = [test_clean_dl, test_other_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dl):
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
rnn_lm_model=rnn_lm_model,
|
||||
HLG=HLG,
|
||||
H=H,
|
||||
bpe_model=bpe_model,
|
||||
word_table=lexicon.word_table,
|
||||
G=G,
|
||||
sos_id=sos_id,
|
||||
eos_id=eos_id,
|
||||
)
|
||||
|
||||
save_results(
|
||||
params=params, test_set_name=test_set, results_dict=results_dict
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
281
egs/librispeech/ASR/conformer_ctc2/export.py
Executable file
281
egs/librispeech/ASR/conformer_ctc2/export.py
Executable file
@ -0,0 +1,281 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||
# Quandong Wang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This script converts several saved checkpoints
|
||||
# to a single one using model averaging.
|
||||
"""
|
||||
Usage:
|
||||
./conformer_ctc2/export.py \
|
||||
--exp-dir ./conformer_ctc2/exp \
|
||||
--epoch 20 \
|
||||
--avg 10
|
||||
|
||||
It will generate a file exp_dir/pretrained.pt
|
||||
|
||||
To use the generated file with `conformer_ctc2/decode.py`,
|
||||
you can do:
|
||||
|
||||
cd /path/to/exp_dir
|
||||
ln -s pretrained.pt epoch-9999.pt
|
||||
|
||||
cd /path/to/egs/librispeech/ASR
|
||||
./conformer_ctc2/decode.py \
|
||||
--exp-dir ./conformer_ctc2/exp \
|
||||
--epoch 9999 \
|
||||
--avg 1 \
|
||||
--max-duration 100
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from decode import get_params
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from conformer import Conformer
|
||||
|
||||
from icefall.utils import str2bool
|
||||
from icefall.lexicon import Lexicon
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=28,
|
||||
help="""It specifies the checkpoint to use for averaging.
|
||||
Note: Epoch counts from 0.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=15,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-decoder-layers",
|
||||
type=int,
|
||||
default=6,
|
||||
help="""Number of decoder layer of transformer decoder.
|
||||
Setting this to 0 will not create the decoder at all (pure CTC model)
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="conformer_ctc2/exp",
|
||||
help="""It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
default="data/lang_bpe_500",
|
||||
help="The lang dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--jit",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="""True to save a model after applying torch.jit.script.
|
||||
""",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
args.lang_dir = Path(args.lang_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
max_token_id = max(lexicon.tokens)
|
||||
num_classes = max_token_id + 1 # +1 for the blank
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
|
||||
model = Conformer(
|
||||
num_features=params.feature_dim,
|
||||
nhead=params.nhead,
|
||||
d_model=params.encoder_dim,
|
||||
num_classes=num_classes,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
num_encoder_layers=params.num_encoder_layers,
|
||||
num_decoder_layers=params.num_decoder_layers,
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg + 1]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.eval()
|
||||
|
||||
model.to("cpu")
|
||||
model.eval()
|
||||
|
||||
if params.jit:
|
||||
logging.info("Using torch.jit.script")
|
||||
model = torch.jit.script(model)
|
||||
filename = params.exp_dir / "cpu_jit.pt"
|
||||
model.save(str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
else:
|
||||
logging.info("Not using torch.jit.script")
|
||||
# Save it using a format so that it can be loaded
|
||||
# by :func:`load_checkpoint`
|
||||
filename = params.exp_dir / "pretrained.pt"
|
||||
torch.save({"model": model.state_dict()}, str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
1
egs/librispeech/ASR/conformer_ctc2/label_smoothing.py
Symbolic link
1
egs/librispeech/ASR/conformer_ctc2/label_smoothing.py
Symbolic link
@ -0,0 +1 @@
|
||||
../conformer_ctc/label_smoothing.py
|
1
egs/librispeech/ASR/conformer_ctc2/optim.py
Symbolic link
1
egs/librispeech/ASR/conformer_ctc2/optim.py
Symbolic link
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless2/optim.py
|
1
egs/librispeech/ASR/conformer_ctc2/scaling.py
Symbolic link
1
egs/librispeech/ASR/conformer_ctc2/scaling.py
Symbolic link
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless2/scaling.py
|
121
egs/librispeech/ASR/conformer_ctc2/subsampling.py
Normal file
121
egs/librispeech/ASR/conformer_ctc2/subsampling.py
Normal file
@ -0,0 +1,121 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
|
||||
# 2022 Xiaomi Corporation (author: Quandong Wang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from scaling import (
|
||||
ActivationBalancer,
|
||||
BasicNorm,
|
||||
DoubleSwish,
|
||||
ScaledConv2d,
|
||||
ScaledLinear,
|
||||
)
|
||||
from torch import nn
|
||||
|
||||
|
||||
class Conv2dSubsampling(nn.Module):
|
||||
"""Convolutional 2D subsampling (to 1/4 length).
|
||||
|
||||
Convert an input of shape (N, T, idim) to an output
|
||||
with shape (N, T', odim), where
|
||||
T' = ((T-1)//2 - 1)//2, which approximates T' == T//4
|
||||
|
||||
It is based on
|
||||
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
layer1_channels: int = 8,
|
||||
layer2_channels: int = 32,
|
||||
layer3_channels: int = 128,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
in_channels:
|
||||
Number of channels in. The input shape is (N, T, in_channels).
|
||||
Caution: It requires: T >=7, in_channels >=7
|
||||
out_channels
|
||||
Output dim. The output shape is (N, ((T-1)//2 - 1)//2, out_channels)
|
||||
layer1_channels:
|
||||
Number of channels in layer1
|
||||
layer1_channels:
|
||||
Number of channels in layer2
|
||||
"""
|
||||
assert in_channels >= 7
|
||||
super().__init__()
|
||||
|
||||
self.conv = nn.Sequential(
|
||||
ScaledConv2d(
|
||||
in_channels=1,
|
||||
out_channels=layer1_channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
),
|
||||
ActivationBalancer(channel_dim=1),
|
||||
DoubleSwish(),
|
||||
ScaledConv2d(
|
||||
in_channels=layer1_channels,
|
||||
out_channels=layer2_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
),
|
||||
ActivationBalancer(channel_dim=1),
|
||||
DoubleSwish(),
|
||||
ScaledConv2d(
|
||||
in_channels=layer2_channels,
|
||||
out_channels=layer3_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
),
|
||||
ActivationBalancer(channel_dim=1),
|
||||
DoubleSwish(),
|
||||
)
|
||||
self.out = ScaledLinear(
|
||||
layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels
|
||||
)
|
||||
# set learn_eps=False because out_norm is preceded by `out`, and `out`
|
||||
# itself has learned scale, so the extra degree of freedom is not
|
||||
# needed.
|
||||
self.out_norm = BasicNorm(out_channels, learn_eps=False)
|
||||
# constrain median of output to be close to zero.
|
||||
self.out_balancer = ActivationBalancer(
|
||||
channel_dim=-1, min_positive=0.45, max_positive=0.55
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Subsample x.
|
||||
|
||||
Args:
|
||||
x:
|
||||
Its shape is (N, T, idim).
|
||||
|
||||
Returns:
|
||||
Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
|
||||
"""
|
||||
# On entry, x is (N, T, idim)
|
||||
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
|
||||
x = self.conv(x)
|
||||
# Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
|
||||
b, c, t, f = x.size()
|
||||
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||
# Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
|
||||
x = self.out_norm(x)
|
||||
x = self.out_balancer(x)
|
||||
return x
|
1128
egs/librispeech/ASR/conformer_ctc2/train.py
Executable file
1128
egs/librispeech/ASR/conformer_ctc2/train.py
Executable file
File diff suppressed because it is too large
Load Diff
1092
egs/librispeech/ASR/conformer_ctc2/transformer.py
Normal file
1092
egs/librispeech/ASR/conformer_ctc2/transformer.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -449,6 +449,7 @@ def decode_dataset(
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
@ -466,9 +467,9 @@ def decode_dataset(
|
||||
for lm_scale, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for hyp_words, ref_text in zip(hyps, texts):
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((ref_words, hyp_words))
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[lm_scale].extend(this_batch)
|
||||
|
||||
@ -496,6 +497,7 @@ def save_results(
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
if enable_log:
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
@ -661,6 +663,8 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
# CAUTION: `test_sets` is for displaying only.
|
||||
# If you want to skip test-clean, you have to skip
|
||||
|
@ -277,10 +277,10 @@ def decode_one_batch(
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
feature_lens += params.right_context_length
|
||||
feature_lens += params.chunk_length
|
||||
feature = torch.nn.functional.pad(
|
||||
feature,
|
||||
pad=(0, 0, 0, params.right_context_length),
|
||||
pad=(0, 0, 0, params.chunk_length),
|
||||
value=LOG_EPS,
|
||||
)
|
||||
|
||||
@ -403,6 +403,7 @@ def decode_dataset(
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
@ -415,9 +416,9 @@ def decode_dataset(
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for hyp_words, ref_text in zip(hyps, texts):
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((ref_words, hyp_words))
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
@ -442,6 +443,7 @@ def save_results(
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
@ -624,6 +626,8 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
|
||||
test_clean_cuts = librispeech.test_clean_cuts()
|
||||
|
@ -1141,8 +1141,8 @@ class EmformerEncoderLayer(nn.Module):
|
||||
- output utterance, with shape (U, B, D);
|
||||
- output right_context, with shape (R, B, D);
|
||||
- output memory, with shape (1, B, D) or (0, B, D).
|
||||
- output state.
|
||||
- updated conv_cache.
|
||||
- updated attention cache.
|
||||
- updated convolution cache.
|
||||
"""
|
||||
R = right_context.size(0)
|
||||
src = torch.cat([right_context, utterance])
|
||||
@ -1252,6 +1252,11 @@ class EmformerEncoder(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert (
|
||||
chunk_length - 1
|
||||
) & chunk_length == 0, "chunk_length should be a power of 2."
|
||||
self.shift = int(math.log(chunk_length, 2))
|
||||
|
||||
self.use_memory = memory_size > 0
|
||||
self.init_memory_op = nn.AvgPool1d(
|
||||
kernel_size=chunk_length,
|
||||
@ -1525,7 +1530,6 @@ class EmformerEncoder(nn.Module):
|
||||
right_context at the end.
|
||||
states (List[torch.Tensor, List[List[torch.Tensor]], List[torch.Tensor]]: # noqa
|
||||
Cached states containing:
|
||||
- past_lens: number of past frames for each sample in batch
|
||||
- attn_caches: attention states from preceding chunk's computation,
|
||||
where each element corresponds to each emformer layer
|
||||
- conv_caches: left context for causal convolution, where each
|
||||
@ -1580,13 +1584,15 @@ class EmformerEncoder(nn.Module):
|
||||
# calcualte padding mask to mask out initial zero caches
|
||||
chunk_mask = make_pad_mask(output_lengths).to(x.device)
|
||||
memory_mask = (
|
||||
torch.div(
|
||||
num_processed_frames, self.chunk_length, rounding_mode="floor"
|
||||
).view(x.size(1), 1)
|
||||
(
|
||||
(num_processed_frames >> self.shift).view(x.size(1), 1)
|
||||
<= torch.arange(self.memory_size, device=x.device).expand(
|
||||
x.size(1), self.memory_size
|
||||
)
|
||||
).flip(1)
|
||||
if self.use_memory
|
||||
else torch.empty(0).to(dtype=torch.bool, device=x.device)
|
||||
)
|
||||
left_context_mask = (
|
||||
num_processed_frames.view(x.size(1), 1)
|
||||
<= torch.arange(self.left_context_length, device=x.device).expand(
|
||||
@ -1631,6 +1637,31 @@ class EmformerEncoder(nn.Module):
|
||||
)
|
||||
return output, output_lengths, output_states
|
||||
|
||||
@torch.jit.export
|
||||
def init_states(self, device: torch.device = torch.device("cpu")):
|
||||
"""Create initial states."""
|
||||
attn_caches = [
|
||||
[
|
||||
torch.zeros(self.memory_size, self.d_model, device=device),
|
||||
torch.zeros(
|
||||
self.left_context_length, self.d_model, device=device
|
||||
),
|
||||
torch.zeros(
|
||||
self.left_context_length, self.d_model, device=device
|
||||
),
|
||||
]
|
||||
for _ in range(self.num_encoder_layers)
|
||||
]
|
||||
conv_caches = [
|
||||
torch.zeros(self.d_model, self.cnn_module_kernel - 1, device=device)
|
||||
for _ in range(self.num_encoder_layers)
|
||||
]
|
||||
states: Tuple[List[List[torch.Tensor]], List[torch.Tensor]] = (
|
||||
attn_caches,
|
||||
conv_caches,
|
||||
)
|
||||
return states
|
||||
|
||||
|
||||
class Emformer(EncoderInterface):
|
||||
def __init__(
|
||||
@ -1655,6 +1686,7 @@ class Emformer(EncoderInterface):
|
||||
|
||||
self.subsampling_factor = subsampling_factor
|
||||
self.right_context_length = right_context_length
|
||||
self.chunk_length = chunk_length
|
||||
if subsampling_factor != 4:
|
||||
raise NotImplementedError("Support only 'subsampling_factor=4'.")
|
||||
if chunk_length % subsampling_factor != 0:
|
||||
@ -1803,6 +1835,11 @@ class Emformer(EncoderInterface):
|
||||
|
||||
return output, output_lengths, output_states
|
||||
|
||||
@torch.jit.export
|
||||
def init_states(self, device: torch.device = torch.device("cpu")):
|
||||
"""Create initial states."""
|
||||
return self.encoder.init_states(device)
|
||||
|
||||
|
||||
class Conv2dSubsampling(nn.Module):
|
||||
"""Convolutional 2D subsampling (to 1/4 length).
|
||||
|
@ -29,6 +29,7 @@ class Stream(object):
|
||||
def __init__(
|
||||
self,
|
||||
params: AttributeDict,
|
||||
cut_id: str,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
LOG_EPS: float = math.log(1e-10),
|
||||
@ -43,15 +44,13 @@ class Stream(object):
|
||||
device:
|
||||
The device to run this stream.
|
||||
"""
|
||||
self.device = device
|
||||
self.LOG_EPS = LOG_EPS
|
||||
self.cut_id = cut_id
|
||||
|
||||
# Containing attention caches and convolution caches
|
||||
self.states: Optional[
|
||||
Tuple[List[List[torch.Tensor]], List[torch.Tensor]]
|
||||
] = None
|
||||
# Initailize zero states.
|
||||
self.init_states(params)
|
||||
|
||||
# It uses different attributes for different decoding methods.
|
||||
self.context_size = params.context_size
|
||||
@ -107,34 +106,11 @@ class Stream(object):
|
||||
def set_ground_truth(self, ground_truth: str) -> None:
|
||||
self.ground_truth = ground_truth
|
||||
|
||||
def init_states(self, params: AttributeDict) -> None:
|
||||
attn_caches = [
|
||||
[
|
||||
torch.zeros(
|
||||
params.memory_size, params.encoder_dim, device=self.device
|
||||
),
|
||||
torch.zeros(
|
||||
params.left_context_length // params.subsampling_factor,
|
||||
params.encoder_dim,
|
||||
device=self.device,
|
||||
),
|
||||
torch.zeros(
|
||||
params.left_context_length // params.subsampling_factor,
|
||||
params.encoder_dim,
|
||||
device=self.device,
|
||||
),
|
||||
]
|
||||
for _ in range(params.num_encoder_layers)
|
||||
]
|
||||
conv_caches = [
|
||||
torch.zeros(
|
||||
params.encoder_dim,
|
||||
params.cnn_module_kernel - 1,
|
||||
device=self.device,
|
||||
)
|
||||
for _ in range(params.num_encoder_layers)
|
||||
]
|
||||
self.states = (attn_caches, conv_caches)
|
||||
def set_states(
|
||||
self, states: Tuple[List[List[torch.Tensor]], List[torch.Tensor]]
|
||||
) -> None:
|
||||
"""Set states."""
|
||||
self.states = states
|
||||
|
||||
def get_feature_chunk(self) -> torch.Tensor:
|
||||
"""Get a chunk of feature frames.
|
||||
@ -164,6 +140,10 @@ class Stream(object):
|
||||
"""Return True if all feature frames are processed."""
|
||||
return self._done
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self.cut_id
|
||||
|
||||
def decoding_result(self) -> List[int]:
|
||||
"""Obtain current decoding result."""
|
||||
if self.decoding_method == "greedy_search":
|
||||
|
@ -74,7 +74,6 @@ from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
from lhotse import CutSet
|
||||
import numpy as np
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
@ -83,6 +82,7 @@ from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from beam_search import Hypothesis, HypothesisList, get_hyps_shape
|
||||
from emformer import LOG_EPSILON, stack_states, unstack_states
|
||||
from kaldifeat import Fbank, FbankOptions
|
||||
from lhotse import CutSet
|
||||
from stream import Stream
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
@ -678,11 +678,14 @@ def decode_dataset(
|
||||
# Each utterance has a Stream.
|
||||
stream = Stream(
|
||||
params=params,
|
||||
cut_id=cut.id,
|
||||
decoding_graph=decoding_graph,
|
||||
device=device,
|
||||
LOG_EPS=LOG_EPSILON,
|
||||
)
|
||||
|
||||
stream.set_states(model.encoder.init_states(device))
|
||||
|
||||
audio: np.ndarray = cut.load_audio()
|
||||
# audio.shape: (1, num_samples)
|
||||
assert len(audio.shape) == 2
|
||||
@ -709,6 +712,7 @@ def decode_dataset(
|
||||
for i in sorted(finished_streams, reverse=True):
|
||||
decode_results.append(
|
||||
(
|
||||
streams[i].id,
|
||||
streams[i].ground_truth.split(),
|
||||
sp.decode(streams[i].decoding_result()).split(),
|
||||
)
|
||||
@ -729,6 +733,7 @@ def decode_dataset(
|
||||
for i in sorted(finished_streams, reverse=True):
|
||||
decode_results.append(
|
||||
(
|
||||
streams[i].id,
|
||||
streams[i].ground_truth.split(),
|
||||
sp.decode(streams[i].decoding_result()).split(),
|
||||
)
|
||||
|
@ -28,7 +28,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
--start-epoch 1 \
|
||||
--exp-dir conv_emformer_transducer_stateless/exp \
|
||||
--full-libri 1 \
|
||||
--max-duration 300 \
|
||||
--max-duration 280 \
|
||||
--master-port 12321 \
|
||||
--num-encoder-layers 12 \
|
||||
--chunk-length 32 \
|
||||
@ -686,6 +686,15 @@ def compute_loss(
|
||||
(feature_lens // params.subsampling_factor).sum().item()
|
||||
)
|
||||
|
||||
# `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa
|
||||
info["utterances"] = feature.size(0)
|
||||
# averaged input duration in frames over utterances
|
||||
info["utt_duration"] = feature_lens.sum().item()
|
||||
# averaged padding proportion over utterances
|
||||
info["utt_pad_proportion"] = (
|
||||
((feature.size(1) - feature_lens) / feature.size(1)).sum().item()
|
||||
)
|
||||
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
info["simple_loss"] = simple_loss.detach().cpu().item()
|
||||
|
@ -0,0 +1 @@
|
||||
../conv_emformer_transducer_stateless/asr_datamodule.py
|
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless2/beam_search.py
|
661
egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py
Executable file
661
egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py
Executable file
@ -0,0 +1,661 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||
# Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
(1) greedy search
|
||||
./conv_emformer_transducer_stateless2/decode.py \
|
||||
--epoch 30 \
|
||||
--avg 10 \
|
||||
--exp-dir conv_emformer_transducer_stateless2/exp \
|
||||
--max-duration 300 \
|
||||
--num-encoder-layers 12 \
|
||||
--chunk-length 32 \
|
||||
--cnn-module-kernel 31 \
|
||||
--left-context-length 32 \
|
||||
--right-context-length 8 \
|
||||
--memory-size 32 \
|
||||
--decoding-method greedy_search \
|
||||
--use-averaged-model True
|
||||
|
||||
(2) modified beam search
|
||||
./conv_emformer_transducer_stateless2/decode.py \
|
||||
--epoch 30 \
|
||||
--avg 10 \
|
||||
--exp-dir conv_emformer_transducer_stateless2/exp \
|
||||
--max-duration 300 \
|
||||
--num-encoder-layers 12 \
|
||||
--chunk-length 32 \
|
||||
--cnn-module-kernel 31 \
|
||||
--left-context-length 32 \
|
||||
--right-context-length 8 \
|
||||
--memory-size 32 \
|
||||
--decoding-method modified_beam_search \
|
||||
--use-averaged-model True \
|
||||
--beam-size 4
|
||||
|
||||
(3) fast beam search
|
||||
./conv_emformer_transducer_stateless2/decode.py \
|
||||
--epoch 30 \
|
||||
--avg 10 \
|
||||
--exp-dir conv_emformer_transducer_stateless2/exp \
|
||||
--max-duration 300 \
|
||||
--num-encoder-layers 12 \
|
||||
--chunk-length 32 \
|
||||
--cnn-module-kernel 31 \
|
||||
--left-context-length 32 \
|
||||
--right-context-length 8 \
|
||||
--memory-size 32 \
|
||||
--decoding-method fast_beam_search \
|
||||
--use-averaged-model True \
|
||||
--beam 4 \
|
||||
--max-contexts 4 \
|
||||
--max-states 8
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless4/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
frame. Used only when --decoding-method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=4,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --decoding-method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame.
|
||||
Used only when --decoding_method is greedy_search""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
batch: dict,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
|
||||
- key: It indicates the setting used for decoding. For example,
|
||||
if greedy_search is used, it would be "greedy_search"
|
||||
If beam search with a beam size of 7 is used, it would be
|
||||
"beam_7"
|
||||
- value: It contains the decoding result. `len(value)` equals to
|
||||
batch size. `value[i]` is the decoding result for the i-th
|
||||
utterance in the given batch.
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
feature = batch["inputs"]
|
||||
assert feature.ndim == 3
|
||||
|
||||
feature = feature.to(device)
|
||||
# at entry, feature is (N, T, C)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
feature_lens += params.chunk_length
|
||||
feature = torch.nn.functional.pad(
|
||||
feature,
|
||||
pad=(0, 0, 0, params.chunk_length),
|
||||
value=LOG_EPS,
|
||||
)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=feature, x_lens=feature_lens
|
||||
)
|
||||
hyps = []
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif (
|
||||
params.decoding_method == "greedy_search"
|
||||
and params.max_sym_per_frame == 1
|
||||
):
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
else:
|
||||
batch_size = encoder_out.size(0)
|
||||
|
||||
for i in range(batch_size):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.decoding_method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.decoding_method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
hyps.append(sp.decode(hyp).split())
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
return {"greedy_search": hyps}
|
||||
elif params.decoding_method == "fast_beam_search":
|
||||
return {
|
||||
(
|
||||
f"beam_{params.beam}_"
|
||||
f"max_contexts_{params.max_contexts}_"
|
||||
f"max_states_{params.max_states}"
|
||||
): hyps
|
||||
}
|
||||
else:
|
||||
return {f"beam_size_{params.beam_size}": hyps}
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
dl:
|
||||
PyTorch's dataloader containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search.
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
Its value is a list of tuples. Each tuple contains two elements:
|
||||
The first is the reference transcript, and the second is the
|
||||
predicted result.
|
||||
"""
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
log_interval = 100
|
||||
else:
|
||||
log_interval = 2
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
decoding_graph=decoding_graph,
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
num_cuts += len(texts)
|
||||
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(
|
||||
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = (
|
||||
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results, enable_log=True
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
params.res_dir
|
||||
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LibriSpeechAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
assert params.decoding_method in (
|
||||
"greedy_search",
|
||||
"beam_search",
|
||||
"fast_beam_search",
|
||||
"modified_beam_search",
|
||||
)
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam}"
|
||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||
params.suffix += f"-max-states-{params.max_states}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += (
|
||||
f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
)
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> and <unk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg + 1]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
else:
|
||||
decoding_graph = None
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
|
||||
test_clean_cuts = librispeech.test_clean_cuts()
|
||||
test_other_cuts = librispeech.test_other_cuts()
|
||||
|
||||
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
|
||||
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
|
||||
|
||||
test_sets = ["test-clean", "test-other"]
|
||||
test_dl = [test_clean_dl, test_other_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dl):
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
decoding_graph=decoding_graph,
|
||||
)
|
||||
|
||||
save_results(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -0,0 +1 @@
|
||||
../conv_emformer_transducer_stateless/decoder.py
|
1841
egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py
Normal file
1841
egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -0,0 +1 @@
|
||||
../conv_emformer_transducer_stateless/encoder_interface.py
|
287
egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py
Executable file
287
egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py
Executable file
@ -0,0 +1,287 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This script converts several saved checkpoints
|
||||
# to a single one using model averaging.
|
||||
"""
|
||||
Usage:
|
||||
./conv_emformer_transducer_stateless2/export.py \
|
||||
--exp-dir ./conv_emformer_transducer_stateless2/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 30 \
|
||||
--avg 10 \
|
||||
--use-averaged-model=True \
|
||||
--num-encoder-layers 12 \
|
||||
--chunk-length 32 \
|
||||
--cnn-module-kernel 31 \
|
||||
--left-context-length 32 \
|
||||
--right-context-length 8 \
|
||||
--memory-size 32 \
|
||||
--jit False
|
||||
|
||||
It will generate a file exp_dir/pretrained.pt
|
||||
|
||||
To use the generated file with `conv_emformer_transducer_stateless2/decode.py`,
|
||||
you can do:
|
||||
|
||||
cd /path/to/exp_dir
|
||||
ln -s pretrained.pt epoch-9999.pt
|
||||
|
||||
cd /path/to/egs/librispeech/ASR
|
||||
./conv_emformer_transducer_stateless2/decode.py \
|
||||
--exp-dir ./conv_emformer_transducer_stateless2/exp \
|
||||
--epoch 9999 \
|
||||
--avg 1 \
|
||||
--max-duration 100 \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--use-averaged-model=False \
|
||||
--num-encoder-layers 12 \
|
||||
--chunk-length 32 \
|
||||
--cnn-module-kernel 31 \
|
||||
--left-context-length 32 \
|
||||
--right-context-length 8 \
|
||||
--memory-size 32
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=28,
|
||||
help="""It specifies the checkpoint to use for averaging.
|
||||
Note: Epoch counts from 0.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=15,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless2/exp",
|
||||
help="""It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--jit",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""True to save a model after applying torch.jit.script.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg + 1]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.eval()
|
||||
|
||||
if params.jit:
|
||||
# We won't use the forward() method of the model in C++, so just ignore
|
||||
# it here.
|
||||
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||
# torch scriptabe.
|
||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||
logging.info("Using torch.jit.script")
|
||||
model = torch.jit.script(model)
|
||||
filename = params.exp_dir / "cpu_jit.pt"
|
||||
model.save(str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
else:
|
||||
logging.info("Not using torch.jit.script")
|
||||
# Save it using a format so that it can be loaded
|
||||
# by :func:`load_checkpoint`
|
||||
filename = params.exp_dir / "pretrained.pt"
|
||||
torch.save({"model": model.state_dict()}, str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
@ -0,0 +1 @@
|
||||
../conv_emformer_transducer_stateless/joiner.py
|
1
egs/librispeech/ASR/conv_emformer_transducer_stateless2/model.py
Symbolic link
1
egs/librispeech/ASR/conv_emformer_transducer_stateless2/model.py
Symbolic link
@ -0,0 +1 @@
|
||||
../conv_emformer_transducer_stateless/model.py
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user