diff --git a/.flake8 b/.flake8
index 090e97971..3f1227b9b 100644
--- a/.flake8
+++ b/.flake8
@@ -2,6 +2,9 @@
show-source=true
statistics=true
max-line-length = 80
+per-file-ignores =
+ # line too long
+ egs/librispeech/ASR/conformer_ctc/conformer.py: E501,
exclude =
.git,
diff --git a/.github/workflows/run-yesno-recipe.yml b/.github/workflows/run-yesno-recipe.yml
new file mode 100644
index 000000000..876b95e71
--- /dev/null
+++ b/.github/workflows/run-yesno-recipe.yml
@@ -0,0 +1,78 @@
+# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
+
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: run-yesno-recipe
+
+on:
+ push:
+ branches:
+ - master
+ pull_request:
+ types: [labeled]
+
+jobs:
+ run-yesno-recipe:
+ if: github.event.label.name == 'ready' || github.event_name == 'push'
+ runs-on: ${{ matrix.os }}
+ strategy:
+ matrix:
+ # os: [ubuntu-18.04, macos-10.15]
+ # TODO: enable macOS for CPU testing
+ os: [ubuntu-18.04]
+ python-version: [3.8]
+ torch: ["1.8.1"]
+ k2-version: ["1.9.dev20210919"]
+ fail-fast: false
+
+ steps:
+ - uses: actions/checkout@v2
+ with:
+ fetch-depth: 0
+
+ - name: Setup Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v1
+ with:
+ python-version: ${{ matrix.python-version }}
+
+ - name: Install libnsdfile and libsox
+ if: startsWith(matrix.os, 'ubuntu')
+ run: |
+ sudo apt update
+ sudo apt install -q -y libsndfile1-dev libsndfile1 ffmpeg
+ sudo apt install -q -y --fix-missing sox libsox-dev libsox-fmt-all
+
+ - name: Install Python dependencies
+ run: |
+ python3 -m pip install -U pip
+ pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
+ python3 -m pip install git+https://github.com/lhotse-speech/lhotse
+
+ # We are in ./icefall and there is a file: requirements.txt in it
+ python3 -m pip install -r requirements.txt
+
+ - name: Run yesno recipe
+ shell: bash
+ working-directory: ${{github.workspace}}
+ run: |
+ export PYTHONPATH=$PWD:$PYTHONPATH
+ echo $PYTHONPATH
+
+
+ cd egs/yesno/ASR
+ ./prepare.sh
+ python3 ./tdnn/train.py
+ python3 ./tdnn/decode.py
+ # TODO: Check that the WER is less than some value
diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml
index 20c3363b4..2a743705a 100644
--- a/.github/workflows/style_check.yml
+++ b/.github/workflows/style_check.yml
@@ -45,7 +45,7 @@ jobs:
- name: Install Python dependencies
run: |
- python3 -m pip install --upgrade pip black flake8
+ python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2
- name: Run flake8
shell: bash
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 9a298877a..150b5258a 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -21,18 +21,19 @@ on:
branches:
- master
pull_request:
- branches:
- - master
+ types: [labeled]
jobs:
test:
+ if: github.event.label.name == 'ready' || github.event_name == 'push'
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04, macos-10.15]
python-version: [3.6, 3.7, 3.8, 3.9]
torch: ["1.8.1"]
- k2-version: ["1.2.dev20210724"]
+ k2-version: ["1.9.dev20210919"]
+
fail-fast: false
steps:
@@ -47,16 +48,24 @@ jobs:
- name: Install Python dependencies
run: |
- python3 -m pip install --upgrade pip pytest kaldialign
+ python3 -m pip install --upgrade pip pytest
pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
+ # icefall requirements
+ pip install -r requirements.txt
- # Don't use: pip install lhotse
- # since it installs a version of PyTorch that is not predictable
- git clone --depth 1 https://github.com/lhotse-speech/lhotse
- cd lhotse
- sed -i.bak "/torch/d" requirements.txt
- pip install -r ./requirements.txt
+ - name: Install graphviz
+ if: startsWith(matrix.os, 'ubuntu')
+ shell: bash
+ run: |
+ python3 -m pip install -qq graphviz
+ sudo apt-get -qq install graphviz
+ - name: Install graphviz
+ if: startsWith(matrix.os, 'macos')
+ shell: bash
+ run: |
+ python3 -m pip install -qq graphviz
+ brew install -q graphviz
- name: Run tests
if: startsWith(matrix.os, 'ubuntu')
diff --git a/.gitignore b/.gitignore
index 839a1c34a..e6c84ca5e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,4 +4,4 @@ path.sh
exp
exp*/
*.pt
-download/
+download
diff --git a/README.md b/README.md
index 9ffd34b6d..dc03c5883 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,61 @@
-Working in progress.
+
+
+
+
+## Installation
+
+Please refer to
+for installation.
+
+## Recipes
+
+Please refer to
+for more information.
+
+We provide two recipes at present:
+
+ - [yesno][yesno]
+ - [LibriSpeech][librispeech]
+
+### yesno
+
+This is the simplest ASR recipe in `icefall` and can be run on CPU.
+Training takes less than 30 seconds and gives you the following WER:
+
+```
+[test_set] %WER 0.42% [1 / 240, 0 ins, 1 del, 0 sub ]
+```
+We do provide a Colab notebook for this recipe.
+
+[](https://colab.research.google.com/drive/1tIjjzaJc3IvGyKiMCDWO-TSnBgkcuN3B?usp=sharing)
+
+
+### LibriSpeech
+
+We provide two models for this recipe: [conformer CTC model][LibriSpeech_conformer_ctc]
+and [TDNN LSTM CTC model][LibriSpeech_tdnn_lstm_ctc].
+
+#### Conformer CTC Model
+
+The best WER we currently have is:
+
+||test-clean|test-other|
+|--|--|--|
+|WER| 2.57% | 5.94% |
+
+We provide a Colab notebook to run a pre-trained conformer CTC model: [](https://colab.research.google.com/drive/1huyupXAcHsUrKaWfI83iMEJ6J0Nh0213?usp=sharing)
+
+#### TDNN LSTM CTC Model
+
+The WER for this model is:
+
+||test-clean|test-other|
+|--|--|--|
+|WER| 6.59% | 17.69% |
+
+We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [](https://colab.research.google.com/drive/1kNmDXNMwREi0rZGAOIAOJo93REBuOTcd?usp=sharing)
+
+[LibriSpeech_tdnn_lstm_ctc]: egs/librispeech/ASR/tdnn_lstm_ctc
+[LibriSpeech_conformer_ctc]: egs/librispeech/ASR/conformer_ctc
+[yesno]: egs/yesno/ASR
+[librispeech]: egs/librispeech/ASR
diff --git a/docs/.gitignore b/docs/.gitignore
new file mode 100644
index 000000000..567609b12
--- /dev/null
+++ b/docs/.gitignore
@@ -0,0 +1 @@
+build/
diff --git a/docs/Makefile b/docs/Makefile
new file mode 100644
index 000000000..d0c3cbf10
--- /dev/null
+++ b/docs/Makefile
@@ -0,0 +1,20 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line, and also
+# from the environment for the first two.
+SPHINXOPTS ?=
+SPHINXBUILD ?= sphinx-build
+SOURCEDIR = source
+BUILDDIR = build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
diff --git a/docs/make.bat b/docs/make.bat
new file mode 100644
index 000000000..6247f7e23
--- /dev/null
+++ b/docs/make.bat
@@ -0,0 +1,35 @@
+@ECHO OFF
+
+pushd %~dp0
+
+REM Command file for Sphinx documentation
+
+if "%SPHINXBUILD%" == "" (
+ set SPHINXBUILD=sphinx-build
+)
+set SOURCEDIR=source
+set BUILDDIR=build
+
+if "%1" == "" goto help
+
+%SPHINXBUILD% >NUL 2>NUL
+if errorlevel 9009 (
+ echo.
+ echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
+ echo.installed, then set the SPHINXBUILD environment variable to point
+ echo.to the full path of the 'sphinx-build' executable. Alternatively you
+ echo.may add the Sphinx directory to PATH.
+ echo.
+ echo.If you don't have Sphinx installed, grab it from
+ echo.http://sphinx-doc.org/
+ exit /b 1
+)
+
+%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+goto end
+
+:help
+%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+
+:end
+popd
diff --git a/docs/requirements.txt b/docs/requirements.txt
new file mode 100644
index 000000000..74640391e
--- /dev/null
+++ b/docs/requirements.txt
@@ -0,0 +1,2 @@
+sphinx_rtd_theme
+sphinx
diff --git a/docs/source/_static/logo.png b/docs/source/_static/logo.png
new file mode 100644
index 000000000..84d42568c
Binary files /dev/null and b/docs/source/_static/logo.png differ
diff --git a/docs/source/conf.py b/docs/source/conf.py
new file mode 100644
index 000000000..599df8b3e
--- /dev/null
+++ b/docs/source/conf.py
@@ -0,0 +1,76 @@
+# Configuration file for the Sphinx documentation builder.
+#
+# This file only contains a selection of the most common options. For a full
+# list see the documentation:
+# https://www.sphinx-doc.org/en/master/usage/configuration.html
+
+# -- Path setup --------------------------------------------------------------
+
+# If extensions (or modules to document with autodoc) are in another directory,
+# add these directories to sys.path here. If the directory is relative to the
+# documentation root, use os.path.abspath to make it absolute, like shown here.
+#
+# import os
+# import sys
+# sys.path.insert(0, os.path.abspath('.'))
+
+import sphinx_rtd_theme
+
+# -- Project information -----------------------------------------------------
+
+project = "icefall"
+copyright = "2021, icefall development team"
+author = "icefall development team"
+
+# The full version, including alpha/beta/rc tags
+release = "0.1"
+
+
+# -- General configuration ---------------------------------------------------
+
+# Add any Sphinx extension module names here, as strings. They can be
+# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
+# ones.
+extensions = [
+ "sphinx_rtd_theme",
+]
+
+# Add any paths that contain templates here, relative to this directory.
+templates_path = ["_templates"]
+
+# List of patterns, relative to source directory, that match files and
+# directories to ignore when looking for source files.
+# This pattern also affects html_static_path and html_extra_path.
+exclude_patterns = []
+
+source_suffix = {
+ ".rst": "restructuredtext",
+}
+master_doc = "index"
+
+
+# -- Options for HTML output -------------------------------------------------
+
+# The theme to use for HTML and HTML Help pages. See the documentation for
+# a list of builtin themes.
+#
+html_theme = "sphinx_rtd_theme"
+html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
+html_show_sourcelink = True
+
+# Add any paths that contain custom static files (such as style sheets) here,
+# relative to this directory. They are copied after the builtin static files,
+# so a file named "default.css" will overwrite the builtin "default.css".
+html_static_path = ["_static", "installation/images"]
+
+pygments_style = "sphinx"
+
+numfig = True
+
+html_context = {
+ "display_github": True,
+ "github_user": "k2-fsa",
+ "github_repo": "icefall",
+ "github_version": "master",
+ "conf_py_path": "/icefall/docs/source/",
+}
diff --git a/docs/source/contributing/code-style.rst b/docs/source/contributing/code-style.rst
new file mode 100644
index 000000000..7d61a3ba1
--- /dev/null
+++ b/docs/source/contributing/code-style.rst
@@ -0,0 +1,67 @@
+.. _follow the code style:
+
+Follow the code style
+=====================
+
+We use the following tools to make the code style to be as consistent as possible:
+
+ - `black `_, to format the code
+ - `flake8 `_, to check the style and quality of the code
+ - `isort `_, to sort ``imports``
+
+The following versions of the above tools are used:
+
+ - ``black == 12.6b0``
+ - ``flake8 == 3.9.2``
+ - ``isort == 5.9.2``
+
+After running the following commands:
+
+ .. code-block::
+
+ $ git clone https://github.com/k2-fsa/icefall
+ $ cd icefall
+ $ pip install pre-commit
+ $ pre-commit install
+
+it will run the following checks whenever you run ``git commit``, **automatically**:
+
+ .. figure:: images/pre-commit-check.png
+ :width: 600
+ :align: center
+
+ pre-commit hooks invoked by ``git commit`` (Failed).
+
+If any of the above checks failed, your ``git commit`` was not successful.
+Please fix any issues reported by the check tools.
+
+.. HINT::
+
+ Some of the check tools, i.e., ``black`` and ``isort`` will modify
+ the files to be commited **in-place**. So please run ``git status``
+ after failure to see which file has been modified by the tools
+ before you make any further changes.
+
+After fixing all the failures, run ``git commit`` again and
+it should succeed this time:
+
+ .. figure:: images/pre-commit-check-success.png
+ :width: 600
+ :align: center
+
+ pre-commit hooks invoked by ``git commit`` (Succeeded).
+
+If you want to check the style of your code before ``git commit``, you
+can do the following:
+
+ .. code-block:: bash
+
+ $ cd icefall
+ $ pip install black==21.6b0 flake8==3.9.2 isort==5.9.2
+ $ black --check your_changed_file.py
+ $ black your_changed_file.py # modify it in-place
+ $
+ $ flake8 your_changed_file.py
+ $
+ $ isort --check your_changed_file.py # modify it in-place
+ $ isort your_changed_file.py
diff --git a/docs/source/contributing/doc.rst b/docs/source/contributing/doc.rst
new file mode 100644
index 000000000..893d8a15e
--- /dev/null
+++ b/docs/source/contributing/doc.rst
@@ -0,0 +1,45 @@
+Contributing to Documentation
+=============================
+
+We use `sphinx `_
+for documentation.
+
+Before writing documentation, you have to prepare the environment:
+
+ .. code-block:: bash
+
+ $ cd docs
+ $ pip install -r requirements.txt
+
+After setting up the environment, you are ready to write documentation.
+Please refer to `reStructuredText Primer `_
+if you are not familiar with ``reStructuredText``.
+
+After writing some documentation, you can build the documentation **locally**
+to preview what it looks like if it is published:
+
+ .. code-block:: bash
+
+ $ cd docs
+ $ make html
+
+The generated documentation is in ``docs/build/html`` and can be viewed
+with the following commands:
+
+ .. code-block:: bash
+
+ $ cd docs/build/html
+ $ python3 -m http.server
+
+It will print::
+
+ Serving HTTP on 0.0.0.0 port 8000 (http://0.0.0.0:8000/) ...
+
+Open your browser, go to ``_, and you will see
+the following:
+
+ .. figure:: images/doc-contrib.png
+ :width: 600
+ :align: center
+
+ View generated documentation locally with ``python3 -m http.server``.
diff --git a/docs/source/contributing/how-to-create-a-recipe.rst b/docs/source/contributing/how-to-create-a-recipe.rst
new file mode 100644
index 000000000..a30fb9056
--- /dev/null
+++ b/docs/source/contributing/how-to-create-a-recipe.rst
@@ -0,0 +1,156 @@
+How to create a recipe
+======================
+
+.. HINT::
+
+ Please read :ref:`follow the code style` to adjust your code sytle.
+
+.. CAUTION::
+
+ ``icefall`` is designed to be as Pythonic as possible. Please use
+ Python in your recipe if possible.
+
+Data Preparation
+----------------
+
+We recommend you to prepare your training/test/validate dataset
+with `lhotse `_.
+
+Please refer to ``_
+for how to create a recipe in ``lhotse``.
+
+.. HINT::
+
+ The ``yesno`` recipe in ``lhotse`` is a very good example.
+
+ Please refer to ``_,
+ which shows how to add a new recipe to ``lhotse``.
+
+Suppose you would like to add a recipe for a dataset named ``foo``.
+You can do the following:
+
+.. code-block::
+
+ $ cd egs
+ $ mkdir -p foo/ASR
+ $ cd foo/ASR
+ $ touch prepare.sh
+ $ chmod +x prepare.sh
+
+If your dataset is very simple, please follow
+`egs/yesno/ASR/prepare.sh `_
+to write your own ``prepare.sh``.
+Otherwise, please refer to
+`egs/librispeech/ASR/prepare.sh `_
+to prepare your data.
+
+
+Training
+--------
+
+Assume you have a fancy model, called ``bar`` for the ``foo`` recipe, you can
+organize your files in the following way:
+
+.. code-block::
+
+ $ cd egs/foo/ASR
+ $ mkdir bar
+ $ cd bar
+ $ touch README.md model.py train.py decode.py asr_datamodule.py pretrained.py
+
+For instance , the ``yesno`` recipe has a ``tdnn`` model and its directory structure
+looks like the following:
+
+.. code-block:: bash
+
+ egs/yesno/ASR/tdnn/
+ |-- README.md
+ |-- asr_datamodule.py
+ |-- decode.py
+ |-- model.py
+ |-- pretrained.py
+ `-- train.py
+
+**File description**:
+
+ - ``README.md``
+
+ It contains information of this recipe, e.g., how to run it, what the WER is, etc.
+
+ - ``asr_datamodule.py``
+
+ It provides code to create PyTorch dataloaders with train/test/validation dataset.
+
+ - ``decode.py``
+
+ It takes as inputs the checkpoints saved during the training stage to decode the test
+ dataset(s).
+
+ - ``model.py``
+
+ It contains the definition of your fancy neural network model.
+
+ - ``pretrained.py``
+
+ We can use this script to do inference with a pre-trained model.
+
+ - ``train.py``
+
+ It contains training code.
+
+
+.. HINT::
+
+ Please take a look at
+
+ - `egs/yesno/tdnn `_
+ - `egs/librispeech/tdnn_lstm_ctc `_
+ - `egs/librispeech/conformer_ctc `_
+
+ to get a feel what the resulting files look like.
+
+.. NOTE::
+
+ Every model in a recipe is kept to be as self-contained as possible.
+ We tolerate duplicate code among different recipes.
+
+
+The training stage should be invocable by:
+
+ .. code-block::
+
+ $ cd egs/foo/ASR
+ $ ./bar/train.py
+ $ ./bar/train.py --help
+
+
+Decoding
+--------
+
+Please refer to
+
+ - ``_
+
+ If your model is transformer/conformer based.
+
+ - ``_
+
+ If your model is TDNN/LSTM based, i.e., there is no attention decoder.
+
+ - ``_
+
+ If there is no LM rescoring.
+
+The decoding stage should be invocable by:
+
+ .. code-block::
+
+ $ cd egs/foo/ASR
+ $ ./bar/decode.py
+ $ ./bar/decode.py --help
+
+Pre-trained model
+-----------------
+
+Please demonstrate how to use your model for inference in ``egs/foo/ASR/bar/pretrained.py``.
+If possible, please consider creating a Colab notebook to show that.
diff --git a/docs/source/contributing/images/doc-contrib.png b/docs/source/contributing/images/doc-contrib.png
new file mode 100644
index 000000000..00906ab83
Binary files /dev/null and b/docs/source/contributing/images/doc-contrib.png differ
diff --git a/docs/source/contributing/images/pre-commit-check-success.png b/docs/source/contributing/images/pre-commit-check-success.png
new file mode 100644
index 000000000..3c6ee9b1c
Binary files /dev/null and b/docs/source/contributing/images/pre-commit-check-success.png differ
diff --git a/docs/source/contributing/images/pre-commit-check.png b/docs/source/contributing/images/pre-commit-check.png
new file mode 100644
index 000000000..80784eced
Binary files /dev/null and b/docs/source/contributing/images/pre-commit-check.png differ
diff --git a/docs/source/contributing/index.rst b/docs/source/contributing/index.rst
new file mode 100644
index 000000000..21c747d33
--- /dev/null
+++ b/docs/source/contributing/index.rst
@@ -0,0 +1,22 @@
+Contributing
+============
+
+Contributions to ``icefall`` are very welcomed.
+There are many possible ways to make contributions and
+two of them are:
+
+ - To write documentation
+ - To write code
+
+ - (1) To follow the code style in the repository
+ - (2) To write a new recipe
+
+In this page, we describe how to contribute documentation
+and code to ``icefall``.
+
+.. toctree::
+ :maxdepth: 2
+
+ doc
+ code-style
+ how-to-create-a-recipe
diff --git a/docs/source/index.rst b/docs/source/index.rst
new file mode 100644
index 000000000..b06047a89
--- /dev/null
+++ b/docs/source/index.rst
@@ -0,0 +1,25 @@
+.. icefall documentation master file, created by
+ sphinx-quickstart on Mon Aug 23 16:07:39 2021.
+ You can adapt this file completely to your liking, but it should at least
+ contain the root `toctree` directive.
+
+Icefall
+=======
+
+.. image:: _static/logo.png
+ :alt: icefall logo
+ :width: 168px
+ :align: center
+ :target: https://github.com/k2-fsa/icefall
+
+
+Documentation for `icefall `_, containing
+speech recognition recipes using `k2 `_.
+
+.. toctree::
+ :maxdepth: 2
+ :caption: Contents:
+
+ installation/index
+ recipes/index
+ contributing/index
diff --git a/docs/source/installation/images/device-CPU_CUDA-orange.svg b/docs/source/installation/images/device-CPU_CUDA-orange.svg
new file mode 100644
index 000000000..a023a1283
--- /dev/null
+++ b/docs/source/installation/images/device-CPU_CUDA-orange.svg
@@ -0,0 +1 @@
+
diff --git a/docs/source/installation/images/k2-v1.9-blueviolet.svg b/docs/source/installation/images/k2-v1.9-blueviolet.svg
new file mode 100644
index 000000000..5a207b370
--- /dev/null
+++ b/docs/source/installation/images/k2-v1.9-blueviolet.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/docs/source/installation/images/os-Linux_macOS-ff69b4.svg b/docs/source/installation/images/os-Linux_macOS-ff69b4.svg
new file mode 100644
index 000000000..178813ed4
--- /dev/null
+++ b/docs/source/installation/images/os-Linux_macOS-ff69b4.svg
@@ -0,0 +1 @@
+
diff --git a/docs/source/installation/images/python-3.6_3.7_3.8_3.9-blue.svg b/docs/source/installation/images/python-3.6_3.7_3.8_3.9-blue.svg
new file mode 100644
index 000000000..befc1e19e
--- /dev/null
+++ b/docs/source/installation/images/python-3.6_3.7_3.8_3.9-blue.svg
@@ -0,0 +1 @@
+
diff --git a/docs/source/installation/images/torch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg b/docs/source/installation/images/torch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg
new file mode 100644
index 000000000..496e5a9ef
--- /dev/null
+++ b/docs/source/installation/images/torch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg
@@ -0,0 +1 @@
+
diff --git a/docs/source/installation/index.rst b/docs/source/installation/index.rst
new file mode 100644
index 000000000..f960033e8
--- /dev/null
+++ b/docs/source/installation/index.rst
@@ -0,0 +1,466 @@
+.. _install icefall:
+
+Installation
+============
+
+- |os|
+- |device|
+- |python_versions|
+- |torch_versions|
+- |k2_versions|
+
+.. |os| image:: ./images/os-Linux_macOS-ff69b4.svg
+ :alt: Supported operating systems
+
+.. |device| image:: ./images/device-CPU_CUDA-orange.svg
+ :alt: Supported devices
+
+.. |python_versions| image:: ./images/python-3.6_3.7_3.8_3.9-blue.svg
+ :alt: Supported python versions
+
+.. |torch_versions| image:: ./images/torch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg
+ :alt: Supported PyTorch versions
+
+.. |k2_versions| image:: ./images/k2-v1.9-blueviolet.svg
+ :alt: Supported k2 versions
+
+``icefall`` depends on `k2 `_ and
+`lhotse `_.
+
+We recommend you to install ``k2`` first, as ``k2`` is bound to
+a specific version of PyTorch after compilation. Install ``k2`` also
+installs its dependency PyTorch, which can be reused by ``lhotse``.
+
+
+(1) Install k2
+--------------
+
+Please refer to ``_
+to install ``k2``.
+
+.. CAUTION::
+
+ You need to install ``k2`` with a version at least **v1.9**.
+
+.. HINT::
+
+ If you have already installed PyTorch and don't want to replace it,
+ please install a version of ``k2`` that is compiled against the version
+ of PyTorch you are using.
+
+(2) Install lhotse
+------------------
+
+Please refer to ``_
+to install ``lhotse``.
+
+.. HINT::
+
+ Install ``lhotse`` also installs its dependency `torchaudio `_.
+
+.. CAUTION::
+
+ If you have installed ``torchaudio``, please consider uninstalling it before
+ installing ``lhotse``. Otherwise, it may update your already installed PyTorch.
+
+(3) Download icefall
+--------------------
+
+``icefall`` is a collection of Python scripts, so you don't need to install it
+and we don't provide a ``setup.py`` to install it.
+
+What you need is to download it and set the environment variable ``PYTHONPATH``
+to point to it.
+
+Assume you want to place ``icefall`` in the folder ``/tmp``. The
+following commands show you how to setup ``icefall``:
+
+
+.. code-block:: bash
+
+ cd /tmp
+ git clone https://github.com/k2-fsa/icefall
+ cd icefall
+ pip install -r requirements.txt
+ export PYTHONPATH=/tmp/icefall:$PYTHONPATH
+
+.. HINT::
+
+ You can put several versions of ``icefall`` in the same virtual environment.
+ To switch among different versions of ``icefall``, just set ``PYTHONPATH``
+ to point to the version you want.
+
+
+Installation example
+--------------------
+
+The following shows an example about setting up the environment.
+
+
+(1) Create a virtual environment
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+ $ virtualenv -p python3.8 test-icefall
+
+ created virtual environment CPython3.8.6.final.0-64 in 1540ms
+ creator CPython3Posix(dest=/ceph-fj/fangjun/test-icefall, clear=False, no_vcs_ignore=False, global=False)
+ seeder FromAppData(download=False, pip=bundle, setuptools=bundle, wheel=bundle, via=copy, app_data_dir=/root/fangjun/.local/share/v
+ irtualenv)
+ added seed packages: pip==21.1.3, setuptools==57.4.0, wheel==0.36.2
+ activators BashActivator,CShellActivator,FishActivator,PowerShellActivator,PythonActivator,XonshActivator
+
+
+(2) Activate your virtual environment
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+ $ source test-icefall/bin/activate
+
+(3) Install k2
+~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+ $ pip install k2==1.4.dev20210822+cpu.torch1.9.0 -f https://k2-fsa.org/nightly/index.html
+
+ Looking in links: https://k2-fsa.org/nightly/index.html
+ Collecting k2==1.4.dev20210822+cpu.torch1.9.0
+ Downloading https://k2-fsa.org/nightly/whl/k2-1.4.dev20210822%2Bcpu.torch1.9.0-cp38-cp38-linux_x86_64.whl (1.6 MB)
+ |________________________________| 1.6 MB 185 kB/s
+ Collecting graphviz
+ Downloading graphviz-0.17-py3-none-any.whl (18 kB)
+ Collecting torch==1.9.0
+ Using cached torch-1.9.0-cp38-cp38-manylinux1_x86_64.whl (831.4 MB)
+ Collecting typing-extensions
+ Using cached typing_extensions-3.10.0.0-py3-none-any.whl (26 kB)
+ Installing collected packages: typing-extensions, torch, graphviz, k2
+ Successfully installed graphviz-0.17 k2-1.4.dev20210822+cpu.torch1.9.0 torch-1.9.0 typing-extensions-3.10.0.0
+
+.. WARNING::
+
+ We choose to install a CPU version of k2 for testing. You would probably want to install
+ a CUDA version of k2.
+
+
+(4) Install lhotse
+~~~~~~~~~~~~~~~~~~
+
+.. code-block::
+
+ $ pip install git+https://github.com/lhotse-speech/lhotse
+
+ Collecting git+https://github.com/lhotse-speech/lhotse
+ Cloning https://github.com/lhotse-speech/lhotse to /tmp/pip-req-build-7b1b76ge
+ Running command git clone -q https://github.com/lhotse-speech/lhotse /tmp/pip-req-build-7b1b76ge
+ Collecting audioread>=2.1.9
+ Using cached audioread-2.1.9-py3-none-any.whl
+ Collecting SoundFile>=0.10
+ Using cached SoundFile-0.10.3.post1-py2.py3-none-any.whl (21 kB)
+ Collecting click>=7.1.1
+ Using cached click-8.0.1-py3-none-any.whl (97 kB)
+ Collecting cytoolz>=0.10.1
+ Using cached cytoolz-0.11.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.9 MB)
+ Collecting dataclasses
+ Using cached dataclasses-0.6-py3-none-any.whl (14 kB)
+ Collecting h5py>=2.10.0
+ Downloading h5py-3.4.0-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (4.5 MB)
+ |________________________________| 4.5 MB 684 kB/s
+ Collecting intervaltree>=3.1.0
+ Using cached intervaltree-3.1.0-py2.py3-none-any.whl
+ Collecting lilcom>=1.1.0
+ Using cached lilcom-1.1.1-cp38-cp38-linux_x86_64.whl
+ Collecting numpy>=1.18.1
+ Using cached numpy-1.21.2-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (15.8 MB)
+ Collecting packaging
+ Using cached packaging-21.0-py3-none-any.whl (40 kB)
+ Collecting pyyaml>=5.3.1
+ Using cached PyYAML-5.4.1-cp38-cp38-manylinux1_x86_64.whl (662 kB)
+ Collecting tqdm
+ Downloading tqdm-4.62.1-py2.py3-none-any.whl (76 kB)
+ |________________________________| 76 kB 2.7 MB/s
+ Collecting torchaudio==0.9.0
+ Downloading torchaudio-0.9.0-cp38-cp38-manylinux1_x86_64.whl (1.9 MB)
+ |________________________________| 1.9 MB 73.1 MB/s
+ Requirement already satisfied: torch==1.9.0 in ./test-icefall/lib/python3.8/site-packages (from torchaudio==0.9.0->lhotse===0.8.0.dev
+ -2a1410b-clean) (1.9.0)
+ Requirement already satisfied: typing-extensions in ./test-icefall/lib/python3.8/site-packages (from torch==1.9.0->torchaudio==0.9.0-
+ >lhotse===0.8.0.dev-2a1410b-clean) (3.10.0.0)
+ Collecting toolz>=0.8.0
+ Using cached toolz-0.11.1-py3-none-any.whl (55 kB)
+ Collecting sortedcontainers<3.0,>=2.0
+ Using cached sortedcontainers-2.4.0-py2.py3-none-any.whl (29 kB)
+ Collecting cffi>=1.0
+ Using cached cffi-1.14.6-cp38-cp38-manylinux1_x86_64.whl (411 kB)
+ Collecting pycparser
+ Using cached pycparser-2.20-py2.py3-none-any.whl (112 kB)
+ Collecting pyparsing>=2.0.2
+ Using cached pyparsing-2.4.7-py2.py3-none-any.whl (67 kB)
+ Building wheels for collected packages: lhotse
+ Building wheel for lhotse (setup.py) ... done
+ Created wheel for lhotse: filename=lhotse-0.8.0.dev_2a1410b_clean-py3-none-any.whl size=342242 sha256=f683444afa4dc0881133206b4646a
+ 9d0f774224cc84000f55d0a67f6e4a37997
+ Stored in directory: /tmp/pip-ephem-wheel-cache-ftu0qysz/wheels/7f/7a/8e/a0bf241336e2e3cb573e1e21e5600952d49f5162454f2e612f
+ WARNING: Built wheel for lhotse is invalid: Metadata 1.2 mandates PEP 440 version, but '0.8.0.dev-2a1410b-clean' is not
+ Failed to build lhotse
+ Installing collected packages: pycparser, toolz, sortedcontainers, pyparsing, numpy, cffi, tqdm, torchaudio, SoundFile, pyyaml, packa
+ ging, lilcom, intervaltree, h5py, dataclasses, cytoolz, click, audioread, lhotse
+ Running setup.py install for lhotse ... done
+ DEPRECATION: lhotse was installed using the legacy 'setup.py install' method, because a wheel could not be built for it. A possible
+ replacement is to fix the wheel build issue reported above. You can find discussion regarding this at https://github.com/pypa/pip/is
+ sues/8368.
+ Successfully installed SoundFile-0.10.3.post1 audioread-2.1.9 cffi-1.14.6 click-8.0.1 cytoolz-0.11.0 dataclasses-0.6 h5py-3.4.0 inter
+ valtree-3.1.0 lhotse-0.8.0.dev-2a1410b-clean lilcom-1.1.1 numpy-1.21.2 packaging-21.0 pycparser-2.20 pyparsing-2.4.7 pyyaml-5.4.1 sor
+ tedcontainers-2.4.0 toolz-0.11.1 torchaudio-0.9.0 tqdm-4.62.1
+
+(5) Download icefall
+~~~~~~~~~~~~~~~~~~~~
+
+.. code-block::
+
+ $ cd /tmp
+ $ git clone https://github.com/k2-fsa/icefall
+
+ Cloning into 'icefall'...
+ remote: Enumerating objects: 500, done.
+ remote: Counting objects: 100% (500/500), done.
+ remote: Compressing objects: 100% (308/308), done.
+ remote: Total 500 (delta 263), reused 307 (delta 102), pack-reused 0
+ Receiving objects: 100% (500/500), 172.49 KiB | 385.00 KiB/s, done.
+ Resolving deltas: 100% (263/263), done.
+
+ $ cd icefall
+ $ pip install -r requirements.txt
+
+ Collecting kaldilm
+ Downloading kaldilm-1.8.tar.gz (48 kB)
+ |________________________________| 48 kB 574 kB/s
+ Collecting kaldialign
+ Using cached kaldialign-0.2-cp38-cp38-linux_x86_64.whl
+ Collecting sentencepiece>=0.1.96
+ Using cached sentencepiece-0.1.96-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
+ Collecting tensorboard
+ Using cached tensorboard-2.6.0-py3-none-any.whl (5.6 MB)
+ Requirement already satisfied: setuptools>=41.0.0 in /ceph-fj/fangjun/test-icefall/lib/python3.8/site-packages (from tensorboard->-r
+ requirements.txt (line 4)) (57.4.0)
+ Collecting absl-py>=0.4
+ Using cached absl_py-0.13.0-py3-none-any.whl (132 kB)
+ Collecting google-auth-oauthlib<0.5,>=0.4.1
+ Using cached google_auth_oauthlib-0.4.5-py2.py3-none-any.whl (18 kB)
+ Collecting grpcio>=1.24.3
+ Using cached grpcio-1.39.0-cp38-cp38-manylinux2014_x86_64.whl (4.3 MB)
+ Requirement already satisfied: wheel>=0.26 in /ceph-fj/fangjun/test-icefall/lib/python3.8/site-packages (from tensorboard->-r require
+ ments.txt (line 4)) (0.36.2)
+ Requirement already satisfied: numpy>=1.12.0 in /ceph-fj/fangjun/test-icefall/lib/python3.8/site-packages (from tensorboard->-r requi
+ rements.txt (line 4)) (1.21.2)
+ Collecting protobuf>=3.6.0
+ Using cached protobuf-3.17.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.0 MB)
+ Collecting werkzeug>=0.11.15
+ Using cached Werkzeug-2.0.1-py3-none-any.whl (288 kB)
+ Collecting tensorboard-data-server<0.7.0,>=0.6.0
+ Using cached tensorboard_data_server-0.6.1-py3-none-manylinux2010_x86_64.whl (4.9 MB)
+ Collecting google-auth<2,>=1.6.3
+ Downloading google_auth-1.35.0-py2.py3-none-any.whl (152 kB)
+ |________________________________| 152 kB 1.4 MB/s
+ Collecting requests<3,>=2.21.0
+ Using cached requests-2.26.0-py2.py3-none-any.whl (62 kB)
+ Collecting tensorboard-plugin-wit>=1.6.0
+ Using cached tensorboard_plugin_wit-1.8.0-py3-none-any.whl (781 kB)
+ Collecting markdown>=2.6.8
+ Using cached Markdown-3.3.4-py3-none-any.whl (97 kB)
+ Collecting six
+ Using cached six-1.16.0-py2.py3-none-any.whl (11 kB)
+ Collecting cachetools<5.0,>=2.0.0
+ Using cached cachetools-4.2.2-py3-none-any.whl (11 kB)
+ Collecting rsa<5,>=3.1.4
+ Using cached rsa-4.7.2-py3-none-any.whl (34 kB)
+ Collecting pyasn1-modules>=0.2.1
+ Using cached pyasn1_modules-0.2.8-py2.py3-none-any.whl (155 kB)
+ Collecting requests-oauthlib>=0.7.0
+ Using cached requests_oauthlib-1.3.0-py2.py3-none-any.whl (23 kB)
+ Collecting pyasn1<0.5.0,>=0.4.6
+ Using cached pyasn1-0.4.8-py2.py3-none-any.whl (77 kB)
+ Collecting urllib3<1.27,>=1.21.1
+ Using cached urllib3-1.26.6-py2.py3-none-any.whl (138 kB)
+ Collecting certifi>=2017.4.17
+ Using cached certifi-2021.5.30-py2.py3-none-any.whl (145 kB)
+ Collecting charset-normalizer~=2.0.0
+ Using cached charset_normalizer-2.0.4-py3-none-any.whl (36 kB)
+ Collecting idna<4,>=2.5
+ Using cached idna-3.2-py3-none-any.whl (59 kB)
+ Collecting oauthlib>=3.0.0
+ Using cached oauthlib-3.1.1-py2.py3-none-any.whl (146 kB)
+ Building wheels for collected packages: kaldilm
+ Building wheel for kaldilm (setup.py) ... done
+ Created wheel for kaldilm: filename=kaldilm-1.8-cp38-cp38-linux_x86_64.whl size=897233 sha256=eccb906cafcd45bf9a7e1a1718e4534254bfb
+ f4c0d0cbc66eee6c88d68a63862
+ Stored in directory: /root/fangjun/.cache/pip/wheels/85/7d/63/f2dd586369b8797cb36d213bf3a84a789eeb92db93d2e723c9
+ Successfully built kaldilm
+ Installing collected packages: urllib3, pyasn1, idna, charset-normalizer, certifi, six, rsa, requests, pyasn1-modules, oauthlib, cach
+ etools, requests-oauthlib, google-auth, werkzeug, tensorboard-plugin-wit, tensorboard-data-server, protobuf, markdown, grpcio, google
+ -auth-oauthlib, absl-py, tensorboard, sentencepiece, kaldilm, kaldialign
+ Successfully installed absl-py-0.13.0 cachetools-4.2.2 certifi-2021.5.30 charset-normalizer-2.0.4 google-auth-1.35.0 google-auth-oaut
+ hlib-0.4.5 grpcio-1.39.0 idna-3.2 kaldialign-0.2 kaldilm-1.8 markdown-3.3.4 oauthlib-3.1.1 protobuf-3.17.3 pyasn1-0.4.8 pyasn1-module
+ s-0.2.8 requests-2.26.0 requests-oauthlib-1.3.0 rsa-4.7.2 sentencepiece-0.1.96 six-1.16.0 tensorboard-2.6.0 tensorboard-data-server-0
+ .6.1 tensorboard-plugin-wit-1.8.0 urllib3-1.26.6 werkzeug-2.0.1
+
+
+Test Your Installation
+----------------------
+
+To test that your installation is successful, let us run
+the `yesno recipe `_
+on CPU.
+
+Data preparation
+~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+ $ export PYTHONPATH=/tmp/icefall:$PYTHONPATH
+ $ cd /tmp/icefall
+ $ cd egs/yesno/ASR
+ $ ./prepare.sh
+
+The log of running ``./prepare.sh`` is:
+
+.. code-block::
+
+ 2021-08-23 19:27:26 (prepare.sh:24:main) dl_dir: /tmp/icefall/egs/yesno/ASR/download
+ 2021-08-23 19:27:26 (prepare.sh:27:main) stage 0: Download data
+ Downloading waves_yesno.tar.gz: 4.49MB [00:03, 1.39MB/s]
+ 2021-08-23 19:27:30 (prepare.sh:36:main) Stage 1: Prepare yesno manifest
+ 2021-08-23 19:27:31 (prepare.sh:42:main) Stage 2: Compute fbank for yesno
+ 2021-08-23 19:27:32,803 INFO [compute_fbank_yesno.py:52] Processing train
+ Extracting and storing features: 100%|_______________________________________________________________| 90/90 [00:01<00:00, 80.57it/s]
+ 2021-08-23 19:27:34,085 INFO [compute_fbank_yesno.py:52] Processing test
+ Extracting and storing features: 100%|______________________________________________________________| 30/30 [00:00<00:00, 248.21it/s]
+ 2021-08-23 19:27:34 (prepare.sh:48:main) Stage 3: Prepare lang
+ 2021-08-23 19:27:35 (prepare.sh:63:main) Stage 4: Prepare G
+ /tmp/pip-install-fcordre9/kaldilm_6899d26f2d684ad48f21025950cd2866/kaldilm/csrc/arpa_file_parser.cc:void kaldilm::ArpaFileParser::Rea
+ d(std::istream&):79
+ [I] Reading \data\ section.
+ /tmp/pip-install-fcordre9/kaldilm_6899d26f2d684ad48f21025950cd2866/kaldilm/csrc/arpa_file_parser.cc:void kaldilm::ArpaFileParser::Rea
+ d(std::istream&):140
+ [I] Reading \1-grams: section.
+ 2021-08-23 19:27:35 (prepare.sh:89:main) Stage 5: Compile HLG
+ 2021-08-23 19:27:35,928 INFO [compile_hlg.py:120] Processing data/lang_phone
+ 2021-08-23 19:27:35,929 INFO [lexicon.py:116] Converting L.pt to Linv.pt
+ 2021-08-23 19:27:35,931 INFO [compile_hlg.py:48] Building ctc_topo. max_token_id: 3
+ 2021-08-23 19:27:35,932 INFO [compile_hlg.py:52] Loading G.fst.txt
+ 2021-08-23 19:27:35,932 INFO [compile_hlg.py:62] Intersecting L and G
+ 2021-08-23 19:27:35,933 INFO [compile_hlg.py:64] LG shape: (4, None)
+ 2021-08-23 19:27:35,933 INFO [compile_hlg.py:66] Connecting LG
+ 2021-08-23 19:27:35,933 INFO [compile_hlg.py:68] LG shape after k2.connect: (4, None)
+ 2021-08-23 19:27:35,933 INFO [compile_hlg.py:70]
+ 2021-08-23 19:27:35,933 INFO [compile_hlg.py:71] Determinizing LG
+ 2021-08-23 19:27:35,934 INFO [compile_hlg.py:74]
+ 2021-08-23 19:27:35,934 INFO [compile_hlg.py:76] Connecting LG after k2.determinize
+ 2021-08-23 19:27:35,934 INFO [compile_hlg.py:79] Removing disambiguation symbols on LG
+ 2021-08-23 19:27:35,934 INFO [compile_hlg.py:87] LG shape after k2.remove_epsilon: (6, None)
+ 2021-08-23 19:27:35,935 INFO [compile_hlg.py:92] Arc sorting LG
+ 2021-08-23 19:27:35,935 INFO [compile_hlg.py:95] Composing H and LG
+ 2021-08-23 19:27:35,935 INFO [compile_hlg.py:102] Connecting LG
+ 2021-08-23 19:27:35,935 INFO [compile_hlg.py:105] Arc sorting LG
+ 2021-08-23 19:27:35,936 INFO [compile_hlg.py:107] HLG.shape: (8, None)
+ 2021-08-23 19:27:35,936 INFO [compile_hlg.py:123] Saving HLG.pt to data/lang_phone
+
+
+Training
+~~~~~~~~
+
+Now let us run the training part:
+
+.. code-block::
+
+ $ export CUDA_VISIBLE_DEVICES=""
+ $ ./tdnn/train.py
+
+.. CAUTION::
+
+ We use ``export CUDA_VISIBLE_DEVICES=""`` so that ``icefall`` uses CPU
+ even if there are GPUs available.
+
+The training log is given below:
+
+.. code-block::
+
+ 2021-08-23 19:30:31,072 INFO [train.py:465] Training started
+ 2021-08-23 19:30:31,072 INFO [train.py:466] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lr': 0.01,
+ 'feature_dim': 23, 'weight_decay': 1e-06, 'start_epoch': 0, 'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, '
+ best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 10, 'valid_interval': 10, 'beam_size': 10, 'reduction': 'sum', 'use_doub
+ le_scores': True, 'world_size': 1, 'master_port': 12354, 'tensorboard': True, 'num_epochs': 15, 'feature_dir': PosixPath('data/fbank'
+ ), 'max_duration': 30.0, 'bucketing_sampler': False, 'num_buckets': 10, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0
+ , 'on_the_fly_feats': False, 'shuffle': True, 'return_cuts': True, 'num_workers': 2}
+ 2021-08-23 19:30:31,074 INFO [lexicon.py:113] Loading pre-compiled data/lang_phone/Linv.pt
+ 2021-08-23 19:30:31,098 INFO [asr_datamodule.py:146] About to get train cuts
+ 2021-08-23 19:30:31,098 INFO [asr_datamodule.py:240] About to get train cuts
+ 2021-08-23 19:30:31,102 INFO [asr_datamodule.py:149] About to create train dataset
+ 2021-08-23 19:30:31,102 INFO [asr_datamodule.py:200] Using SingleCutSampler.
+ 2021-08-23 19:30:31,102 INFO [asr_datamodule.py:206] About to create train dataloader
+ 2021-08-23 19:30:31,102 INFO [asr_datamodule.py:219] About to get test cuts
+ 2021-08-23 19:30:31,102 INFO [asr_datamodule.py:246] About to get test cuts
+ 2021-08-23 19:30:31,357 INFO [train.py:416] Epoch 0, batch 0, batch avg loss 1.0789, total avg loss: 1.0789, batch size: 4
+ 2021-08-23 19:30:31,848 INFO [train.py:416] Epoch 0, batch 10, batch avg loss 0.5356, total avg loss: 0.7556, batch size: 4
+ 2021-08-23 19:30:32,301 INFO [train.py:432] Epoch 0, valid loss 0.9972, best valid loss: 0.9972 best valid epoch: 0
+ 2021-08-23 19:30:32,805 INFO [train.py:416] Epoch 0, batch 20, batch avg loss 0.2436, total avg loss: 0.5717, batch size: 3
+ 2021-08-23 19:30:33,109 INFO [train.py:432] Epoch 0, valid loss 0.4167, best valid loss: 0.4167 best valid epoch: 0
+ 2021-08-23 19:30:33,121 INFO [checkpoint.py:62] Saving checkpoint to tdnn/exp/epoch-0.pt
+ 2021-08-23 19:30:33,325 INFO [train.py:416] Epoch 1, batch 0, batch avg loss 0.2214, total avg loss: 0.2214, batch size: 5
+ 2021-08-23 19:30:33,798 INFO [train.py:416] Epoch 1, batch 10, batch avg loss 0.0781, total avg loss: 0.1343, batch size: 5
+ 2021-08-23 19:30:34,065 INFO [train.py:432] Epoch 1, valid loss 0.0859, best valid loss: 0.0859 best valid epoch: 1
+ 2021-08-23 19:30:34,556 INFO [train.py:416] Epoch 1, batch 20, batch avg loss 0.0421, total avg loss: 0.0975, batch size: 3
+ 2021-08-23 19:30:34,810 INFO [train.py:432] Epoch 1, valid loss 0.0431, best valid loss: 0.0431 best valid epoch: 1
+ 2021-08-23 19:30:34,824 INFO [checkpoint.py:62] Saving checkpoint to tdnn/exp/epoch-1.pt
+
+ ... ...
+
+ 2021-08-23 19:30:49,657 INFO [train.py:416] Epoch 13, batch 0, batch avg loss 0.0109, total avg loss: 0.0109, batch size: 5
+ 2021-08-23 19:30:49,984 INFO [train.py:416] Epoch 13, batch 10, batch avg loss 0.0093, total avg loss: 0.0096, batch size: 4
+ 2021-08-23 19:30:50,239 INFO [train.py:432] Epoch 13, valid loss 0.0104, best valid loss: 0.0101 best valid epoch: 12
+ 2021-08-23 19:30:50,569 INFO [train.py:416] Epoch 13, batch 20, batch avg loss 0.0092, total avg loss: 0.0096, batch size: 2
+ 2021-08-23 19:30:50,819 INFO [train.py:432] Epoch 13, valid loss 0.0101, best valid loss: 0.0101 best valid epoch: 13
+ 2021-08-23 19:30:50,835 INFO [checkpoint.py:62] Saving checkpoint to tdnn/exp/epoch-13.pt
+ 2021-08-23 19:30:51,024 INFO [train.py:416] Epoch 14, batch 0, batch avg loss 0.0105, total avg loss: 0.0105, batch size: 5
+ 2021-08-23 19:30:51,317 INFO [train.py:416] Epoch 14, batch 10, batch avg loss 0.0099, total avg loss: 0.0097, batch size: 4
+ 2021-08-23 19:30:51,552 INFO [train.py:432] Epoch 14, valid loss 0.0108, best valid loss: 0.0101 best valid epoch: 13
+ 2021-08-23 19:30:51,869 INFO [train.py:416] Epoch 14, batch 20, batch avg loss 0.0096, total avg loss: 0.0097, batch size: 5
+ 2021-08-23 19:30:52,107 INFO [train.py:432] Epoch 14, valid loss 0.0102, best valid loss: 0.0101 best valid epoch: 13
+ 2021-08-23 19:30:52,126 INFO [checkpoint.py:62] Saving checkpoint to tdnn/exp/epoch-14.pt
+ 2021-08-23 19:30:52,128 INFO [train.py:537] Done!
+
+Decoding
+~~~~~~~~
+
+Let us use the trained model to decode the test set:
+
+.. code-block::
+
+ $ ./tdnn/decode.py
+
+The decoding log is:
+
+.. code-block::
+
+ 2021-08-23 19:35:30,192 INFO [decode.py:249] Decoding started
+ 2021-08-23 19:35:30,192 INFO [decode.py:250] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lm_dir': PosixPath('data/lm'), 'feature_dim': 23, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'epoch': 14, 'avg': 2, 'feature_dir': PosixPath('data/fbank'), 'max_duration': 30.0, 'bucketing_sampler': False, 'num_buckets': 10, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': True, 'return_cuts': True, 'num_workers': 2}
+ 2021-08-23 19:35:30,193 INFO [lexicon.py:113] Loading pre-compiled data/lang_phone/Linv.pt
+ 2021-08-23 19:35:30,213 INFO [decode.py:259] device: cpu
+ 2021-08-23 19:35:30,217 INFO [decode.py:279] averaging ['tdnn/exp/epoch-13.pt', 'tdnn/exp/epoch-14.pt']
+ /tmp/icefall/icefall/checkpoint.py:146: UserWarning: floor_divide is deprecated, and will be removed in a future version of pytorch.
+ It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values.
+ To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at /pytorch/aten/src/ATen/native/BinaryOps.cpp:450.)
+ avg[k] //= n
+ 2021-08-23 19:35:30,220 INFO [asr_datamodule.py:219] About to get test cuts
+ 2021-08-23 19:35:30,220 INFO [asr_datamodule.py:246] About to get test cuts
+ 2021-08-23 19:35:30,409 INFO [decode.py:190] batch 0/8, cuts processed until now is 4
+ 2021-08-23 19:35:30,571 INFO [decode.py:228] The transcripts are stored in tdnn/exp/recogs-test_set.txt
+ 2021-08-23 19:35:30,572 INFO [utils.py:317] [test_set] %WER 0.42% [1 / 240, 0 ins, 1 del, 0 sub ]
+ 2021-08-23 19:35:30,573 INFO [decode.py:236] Wrote detailed error stats to tdnn/exp/errs-test_set.txt
+ 2021-08-23 19:35:30,573 INFO [decode.py:299] Done!
+
+**Congratulations!** You have successfully setup the environment and have run the first recipe in ``icefall``.
+
+Have fun with ``icefall``!
diff --git a/docs/source/recipes/images/yesno-tdnn-tensorboard-log.png b/docs/source/recipes/images/yesno-tdnn-tensorboard-log.png
new file mode 100644
index 000000000..3d2612c9c
Binary files /dev/null and b/docs/source/recipes/images/yesno-tdnn-tensorboard-log.png differ
diff --git a/docs/source/recipes/index.rst b/docs/source/recipes/index.rst
new file mode 100644
index 000000000..36f8dfc39
--- /dev/null
+++ b/docs/source/recipes/index.rst
@@ -0,0 +1,17 @@
+Recipes
+=======
+
+This page contains various recipes in ``icefall``.
+Currently, only speech recognition recipes are provided.
+
+We may add recipes for other tasks as well in the future.
+
+.. we put the yesno recipe as the first recipe since it is the simplest one.
+.. Other recipes are listed in a alphabetical order.
+
+.. toctree::
+ :maxdepth: 2
+
+ yesno
+
+ librispeech
diff --git a/docs/source/recipes/librispeech.rst b/docs/source/recipes/librispeech.rst
new file mode 100644
index 000000000..946b23407
--- /dev/null
+++ b/docs/source/recipes/librispeech.rst
@@ -0,0 +1,10 @@
+LibriSpeech
+===========
+
+We provide the following models for the LibriSpeech dataset:
+
+.. toctree::
+ :maxdepth: 2
+
+ librispeech/tdnn_lstm_ctc
+ librispeech/conformer_ctc
diff --git a/docs/source/recipes/librispeech/conformer_ctc.rst b/docs/source/recipes/librispeech/conformer_ctc.rst
new file mode 100644
index 000000000..40100bc5a
--- /dev/null
+++ b/docs/source/recipes/librispeech/conformer_ctc.rst
@@ -0,0 +1,631 @@
+Confromer CTC
+=============
+
+This tutorial shows you how to run a conformer ctc model
+with the `LibriSpeech `_ dataset.
+
+
+.. HINT::
+
+ We assume you have read the page :ref:`install icefall` and have setup
+ the environment for ``icefall``.
+
+.. HINT::
+
+ We recommend you to use a GPU or several GPUs to run this recipe.
+
+In this tutorial, you will learn:
+
+ - (1) How to prepare data for training and decoding
+ - (2) How to start the training, either with a single GPU or multiple GPUs
+ - (3) How to do decoding after training, with n-gram LM rescoring and attention decoder rescoring
+ - (4) How to use a pre-trained model, provided by us
+
+Data preparation
+----------------
+
+.. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ ./prepare.sh
+
+The script ``./prepare.sh`` handles the data preparation for you, **automagically**.
+All you need to do is to run it.
+
+The data preparation contains several stages, you can use the following two
+options:
+
+ - ``--stage``
+ - ``--stop-stage``
+
+to control which stage(s) should be run. By default, all stages are executed.
+
+
+For example,
+
+.. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ ./prepare.sh --stage 0 --stop-stage 0
+
+means to run only stage 0.
+
+To run stage 2 to stage 5, use:
+
+.. code-block:: bash
+
+ $ ./prepare.sh --stage 2 --stop-stage 5
+
+.. HINT::
+
+ If you have pre-downloaded the `LibriSpeech `_
+ dataset and the `musan `_ dataset, say,
+ they are saved in ``/tmp/LibriSpeech`` and ``/tmp/musan``, you can modify
+ the ``dl_dir`` variable in ``./prepare.sh`` to point to ``/tmp`` so that
+ ``./prepare.sh`` won't re-download them.
+
+.. NOTE::
+
+ All generated files by ``./prepare.sh``, e.g., features, lexicon, etc,
+ are saved in ``./data`` directory.
+
+
+Training
+--------
+
+Configurable options
+~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ ./conformer_ctc/train.py --help
+
+shows you the training options that can be passed from the commandline.
+The following options are used quite often:
+
+ - ``--full-libri``
+
+ If it's True, the training part uses all the training data, i.e.,
+ 960 hours. Otherwise, the training part uses only the subset
+ ``train-clean-100``, which has 100 hours of training data.
+
+ .. CAUTION::
+
+ The training set is perturbed by speed with two factors: 0.9 and 1.1.
+ If ``--full-libri`` is True, each epoch actually processes
+ ``3x960 == 2880`` hours of data.
+
+ - ``--num-epochs``
+
+ It is the number of epochs to train. For instance,
+ ``./conformer_ctc/train.py --num-epochs 30`` trains for 30 epochs
+ and generates ``epoch-0.pt``, ``epoch-1.pt``, ..., ``epoch-29.pt``
+ in the folder ``./conformer_ctc/exp``.
+
+ - ``--start-epoch``
+
+ It's used to resume training.
+ ``./conformer_ctc/train.py --start-epoch 10`` loads the
+ checkpoint ``./conformer_ctc/exp/epoch-9.pt`` and starts
+ training from epoch 10, based on the state from epoch 9.
+
+ - ``--world-size``
+
+ It is used for multi-GPU single-machine DDP training.
+
+ - (a) If it is 1, then no DDP training is used.
+
+ - (b) If it is 2, then GPU 0 and GPU 1 are used for DDP training.
+
+ The following shows some use cases with it.
+
+ **Use case 1**: You have 4 GPUs, but you only want to use GPU 0 and
+ GPU 2 for training. You can do the following:
+
+ .. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ export CUDA_VISIBLE_DEVICES="0,2"
+ $ ./conformer_ctc/train.py --world-size 2
+
+ **Use case 2**: You have 4 GPUs and you want to use all of them
+ for training. You can do the following:
+
+ .. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ ./conformer_ctc/train.py --world-size 4
+
+ **Use case 3**: You have 4 GPUs but you only want to use GPU 3
+ for training. You can do the following:
+
+ .. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ export CUDA_VISIBLE_DEVICES="3"
+ $ ./conformer_ctc/train.py --world-size 1
+
+ .. CAUTION::
+
+ Only multi-GPU single-machine DDP training is implemented at present.
+ Multi-GPU multi-machine DDP training will be added later.
+
+ - ``--max-duration``
+
+ It specifies the number of seconds over all utterances in a
+ batch, before **padding**.
+ If you encounter CUDA OOM, please reduce it. For instance, if
+ your are using V100 NVIDIA GPU, we recommend you to set it to ``200``.
+
+ .. HINT::
+
+ Due to padding, the number of seconds of all utterances in a
+ batch will usually be larger than ``--max-duration``.
+
+ A larger value for ``--max-duration`` may cause OOM during training,
+ while a smaller value may increase the training time. You have to
+ tune it.
+
+
+Pre-configured options
+~~~~~~~~~~~~~~~~~~~~~~
+
+There are some training options, e.g., weight decay,
+number of warmup steps, results dir, etc,
+that are not passed from the commandline.
+They are pre-configured by the function ``get_params()`` in
+`conformer_ctc/train.py `_
+
+You don't need to change these pre-configured parameters. If you really need to change
+them, please modify ``./conformer_ctc/train.py`` directly.
+
+
+Training logs
+~~~~~~~~~~~~~
+
+Training logs and checkpoints are saved in ``conformer_ctc/exp``.
+You will find the following files in that directory:
+
+ - ``epoch-0.pt``, ``epoch-1.pt``, ...
+
+ These are checkpoint files, containing model ``state_dict`` and optimizer ``state_dict``.
+ To resume training from some checkpoint, say ``epoch-10.pt``, you can use:
+
+ .. code-block:: bash
+
+ $ ./conformer_ctc/train.py --start-epoch 11
+
+ - ``tensorboard/``
+
+ This folder contains TensorBoard logs. Training loss, validation loss, learning
+ rate, etc, are recorded in these logs. You can visualize them by:
+
+ .. code-block:: bash
+
+ $ cd conformer_ctc/exp/tensorboard
+ $ tensorboard dev upload --logdir . --description "Conformer CTC training for LibriSpeech with icefall"
+
+ It will print something like below:
+
+ .. code-block::
+
+ TensorFlow installation not found - running with reduced feature set.
+ Upload started and will continue reading any new data as it's added to the logdir.
+
+ To stop uploading, press Ctrl-C.
+
+ New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/lzGnETjwRxC3yghNMd4kPw/
+
+ [2021-08-24T16:42:43] Started scanning logdir.
+ Uploading 4540 scalars...
+
+ Note there is a URL in the above output, click it and you will see
+ the following screenshot:
+
+ .. figure:: images/librispeech-conformer-ctc-tensorboard-log.png
+ :width: 600
+ :alt: TensorBoard screenshot
+ :align: center
+ :target: https://tensorboard.dev/experiment/lzGnETjwRxC3yghNMd4kPw/
+
+ TensorBoard screenshot.
+
+ - ``log/log-train-xxxx``
+
+ It is the detailed training log in text format, same as the one
+ you saw printed to the console during training.
+
+Usage examples
+~~~~~~~~~~~~~~
+
+The following shows typical use cases:
+
+**Case 1**
+^^^^^^^^^^
+
+.. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ ./conformer_ctc/train.py --max-duration 200 --full-libri 0
+
+It uses ``--max-duration`` of 200 to avoid OOM. Also, it uses only
+a subset of the LibriSpeech data for training.
+
+
+**Case 2**
+^^^^^^^^^^
+
+.. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ export CUDA_VISIBLE_DEVICES="0,3"
+ $ ./conformer_ctc/train.py --world-size 2
+
+It uses GPU 0 and GPU 3 for DDP training.
+
+**Case 3**
+^^^^^^^^^^
+
+.. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ ./conformer_ctc/train.py --num-epochs 10 --start-epoch 3
+
+It loads checkpoint ``./conformer_ctc/exp/epoch-2.pt`` and starts
+training from epoch 3. Also, it trains for 10 epochs.
+
+Decoding
+--------
+
+The decoding part uses checkpoints saved by the training part, so you have
+to run the training part first.
+
+.. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ ./conformer_ctc/decode.py --help
+
+shows the options for decoding.
+
+The commonly used options are:
+
+ - ``--method``
+
+ This specifies the decoding method.
+
+ The following command uses attention decoder for rescoring:
+
+ .. code-block::
+
+ $ cd egs/librispeech/ASR
+ $ ./conformer_ctc/decode.py --method attention-decoder --max-duration 30 --lattice-score-scale 0.5
+
+ - ``--lattice-score-scale``
+
+ It is used to scale down lattice scores so that there are more unique
+ paths for rescoring.
+
+ - ``--max-duration``
+
+ It has the same meaning as the one during training. A larger
+ value may cause OOM.
+
+Pre-trained Model
+-----------------
+
+We have uploaded a pre-trained model to
+``_.
+
+We describe how to use the pre-trained model to transcribe a sound file or
+multiple sound files in the following.
+
+Install kaldifeat
+~~~~~~~~~~~~~~~~~
+
+`kaldifeat `_ is used to
+extract features for a single sound file or multiple sound files
+at the same time.
+
+Please refer to ``_ for installation.
+
+Download the pre-trained model
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The following commands describe how to download the pre-trained model:
+
+.. code-block::
+
+ $ cd egs/librispeech/ASR
+ $ mkdir tmp
+ $ cd tmp
+ $ git lfs install
+ $ git clone https://huggingface.co/pkufool/icefall_asr_librispeech_conformer_ctc
+
+.. CAUTION::
+
+ You have to use ``git lfs`` to download the pre-trained model.
+
+.. CAUTION::
+
+ In order to use this pre-trained model, your k2 version has to be v1.7 or later.
+
+After downloading, you will have the following files:
+
+.. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ tree tmp
+
+.. code-block:: bash
+
+ tmp
+ `-- icefall_asr_librispeech_conformer_ctc
+ |-- README.md
+ |-- data
+ | |-- lang_bpe
+ | | |-- HLG.pt
+ | | |-- bpe.model
+ | | |-- tokens.txt
+ | | `-- words.txt
+ | `-- lm
+ | `-- G_4_gram.pt
+ |-- exp
+ | `-- pretrained.pt
+ `-- test_wavs
+ |-- 1089-134686-0001.flac
+ |-- 1221-135766-0001.flac
+ |-- 1221-135766-0002.flac
+ `-- trans.txt
+
+ 6 directories, 11 files
+
+**File descriptions**:
+
+ - ``data/lang_bpe/HLG.pt``
+
+ It is the decoding graph.
+
+ - ``data/lang_bpe/bpe.model``
+
+ It is a sentencepiece model. You can use it to reproduce our results.
+
+ - ``data/lang_bpe/tokens.txt``
+
+ It contains tokens and their IDs, generated from ``bpe.model``.
+ Provided only for convenience so that you can look up the SOS/EOS ID easily.
+
+ - ``data/lang_bpe/words.txt``
+
+ It contains words and their IDs.
+
+ - ``data/lm/G_4_gram.pt``
+
+ It is a 4-gram LM, used for n-gram LM rescoring.
+
+ - ``exp/pretrained.pt``
+
+ It contains pre-trained model parameters, obtained by averaging
+ checkpoints from ``epoch-15.pt`` to ``epoch-34.pt``.
+ Note: We have removed optimizer ``state_dict`` to reduce file size.
+
+ - ``test_waves/*.flac``
+
+ It contains some test sound files from LibriSpeech ``test-clean`` dataset.
+
+ - ``test_waves/trans.txt``
+
+ It contains the reference transcripts for the sound files in ``test_waves/``.
+
+The information of the test sound files is listed below:
+
+.. code-block:: bash
+
+ $ soxi tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/*.flac
+
+ Input File : 'tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac'
+ Channels : 1
+ Sample Rate : 16000
+ Precision : 16-bit
+ Duration : 00:00:06.62 = 106000 samples ~ 496.875 CDDA sectors
+ File Size : 116k
+ Bit Rate : 140k
+ Sample Encoding: 16-bit FLAC
+
+ Input File : 'tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac'
+ Channels : 1
+ Sample Rate : 16000
+ Precision : 16-bit
+ Duration : 00:00:16.71 = 267440 samples ~ 1253.62 CDDA sectors
+ File Size : 343k
+ Bit Rate : 164k
+ Sample Encoding: 16-bit FLAC
+
+ Input File : 'tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac'
+ Channels : 1
+ Sample Rate : 16000
+ Precision : 16-bit
+ Duration : 00:00:04.83 = 77200 samples ~ 361.875 CDDA sectors
+ File Size : 105k
+ Bit Rate : 174k
+ Sample Encoding: 16-bit FLAC
+
+ Total Duration of 3 files: 00:00:28.16
+
+Usage
+~~~~~
+
+.. code-block::
+
+ $ cd egs/librispeech/ASR
+ $ ./conformer_ctc/pretrained.py --help
+
+displays the help information.
+
+It supports three decoding methods:
+
+ - HLG decoding
+ - HLG + n-gram LM rescoring
+ - HLG + n-gram LM rescoring + attention decoder rescoring
+
+HLG decoding
+^^^^^^^^^^^^
+
+HLG decoding uses the best path of the decoding lattice as the decoding result.
+
+The command to run HLG decoding is:
+
+.. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ ./conformer_ctc/pretrained.py \
+ --checkpoint ./tmp/icefall_asr_librispeech_conformer_ctc/exp/pretrained.pt \
+ --words-file ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/words.txt \
+ --HLG ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt \
+ ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac \
+ ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac \
+ ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac
+
+The output is given below:
+
+.. code-block::
+
+ 2021-08-20 11:03:05,712 INFO [pretrained.py:217] device: cuda:0
+ 2021-08-20 11:03:05,712 INFO [pretrained.py:219] Creating model
+ 2021-08-20 11:03:11,345 INFO [pretrained.py:238] Loading HLG from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt
+ 2021-08-20 11:03:18,442 INFO [pretrained.py:255] Constructing Fbank computer
+ 2021-08-20 11:03:18,444 INFO [pretrained.py:265] Reading sound files: ['./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac']
+ 2021-08-20 11:03:18,507 INFO [pretrained.py:271] Decoding started
+ 2021-08-20 11:03:18,795 INFO [pretrained.py:300] Use HLG decoding
+ 2021-08-20 11:03:19,149 INFO [pretrained.py:339]
+ ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac:
+ AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
+
+ ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac:
+ GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONOURED
+ BOSOM TO CONNECT HER PARENT FOR EVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN
+
+ ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac:
+ YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
+
+ 2021-08-20 11:03:19,149 INFO [pretrained.py:341] Decoding Done
+
+HLG decoding + LM rescoring
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+It uses an n-gram LM to rescore the decoding lattice and the best
+path of the rescored lattice is the decoding result.
+
+The command to run HLG decoding + LM rescoring is:
+
+.. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ ./conformer_ctc/pretrained.py \
+ --checkpoint ./tmp/icefall_asr_librispeech_conformer_ctc/exp/pretrained.pt \
+ --words-file ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/words.txt \
+ --HLG ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt \
+ --method whole-lattice-rescoring \
+ --G ./tmp/icefall_asr_librispeech_conformer_ctc/data/lm/G_4_gram.pt \
+ --ngram-lm-scale 0.8 \
+ ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac \
+ ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac \
+ ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac
+
+Its output is:
+
+.. code-block::
+
+ 2021-08-20 11:12:17,565 INFO [pretrained.py:217] device: cuda:0
+ 2021-08-20 11:12:17,565 INFO [pretrained.py:219] Creating model
+ 2021-08-20 11:12:23,728 INFO [pretrained.py:238] Loading HLG from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt
+ 2021-08-20 11:12:30,035 INFO [pretrained.py:246] Loading G from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lm/G_4_gram.pt
+ 2021-08-20 11:13:10,779 INFO [pretrained.py:255] Constructing Fbank computer
+ 2021-08-20 11:13:10,787 INFO [pretrained.py:265] Reading sound files: ['./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac']
+ 2021-08-20 11:13:10,798 INFO [pretrained.py:271] Decoding started
+ 2021-08-20 11:13:11,085 INFO [pretrained.py:305] Use HLG decoding + LM rescoring
+ 2021-08-20 11:13:11,736 INFO [pretrained.py:339]
+ ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac:
+ AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
+
+ ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac:
+ GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONOURED
+ BOSOM TO CONNECT HER PARENT FOR EVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN
+
+ ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac:
+ YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
+
+ 2021-08-20 11:13:11,737 INFO [pretrained.py:341] Decoding Done
+
+HLG decoding + LM rescoring + attention decoder rescoring
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+It uses an n-gram LM to rescore the decoding lattice, extracts
+n paths from the rescored lattice, recores the extracted paths with
+an attention decoder. The path with the highest score is the decoding result.
+
+The command to run HLG decoding + LM rescoring + attention decoder rescoring is:
+
+.. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ ./conformer_ctc/pretrained.py \
+ --checkpoint ./tmp/icefall_asr_librispeech_conformer_ctc/exp/pretrained.pt \
+ --words-file ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/words.txt \
+ --HLG ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt \
+ --method attention-decoder \
+ --G ./tmp/icefall_asr_librispeech_conformer_ctc/data/lm/G_4_gram.pt \
+ --ngram-lm-scale 1.3 \
+ --attention-decoder-scale 1.2 \
+ --lattice-score-scale 0.5 \
+ --num-paths 100 \
+ --sos-id 1 \
+ --eos-id 1 \
+ ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac \
+ ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac \
+ ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac
+
+The output is below:
+
+.. code-block::
+
+ 2021-08-20 11:19:11,397 INFO [pretrained.py:217] device: cuda:0
+ 2021-08-20 11:19:11,397 INFO [pretrained.py:219] Creating model
+ 2021-08-20 11:19:17,354 INFO [pretrained.py:238] Loading HLG from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt
+ 2021-08-20 11:19:24,615 INFO [pretrained.py:246] Loading G from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lm/G_4_gram.pt
+ 2021-08-20 11:20:04,576 INFO [pretrained.py:255] Constructing Fbank computer
+ 2021-08-20 11:20:04,584 INFO [pretrained.py:265] Reading sound files: ['./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac']
+ 2021-08-20 11:20:04,595 INFO [pretrained.py:271] Decoding started
+ 2021-08-20 11:20:04,854 INFO [pretrained.py:313] Use HLG + LM rescoring + attention decoder rescoring
+ 2021-08-20 11:20:05,805 INFO [pretrained.py:339]
+ ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac:
+ AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
+
+ ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac:
+ GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONOURED
+ BOSOM TO CONNECT HER PARENT FOR EVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN
+
+ ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac:
+ YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
+
+ 2021-08-20 11:20:05,805 INFO [pretrained.py:341] Decoding Done
+
+Colab notebook
+--------------
+
+We do provide a colab notebook for this recipe showing how to use a pre-trained model.
+
+|librispeech asr conformer ctc colab notebook|
+
+.. |librispeech asr conformer ctc colab notebook| image:: https://colab.research.google.com/assets/colab-badge.svg
+ :target: https://colab.research.google.com/drive/1huyupXAcHsUrKaWfI83iMEJ6J0Nh0213?usp=sharing
+
+.. HINT::
+
+ Due to limited memory provided by Colab, you have to upgrade to Colab Pro to
+ run ``HLG decoding + LM rescoring`` and
+ ``HLG decoding + LM rescoring + attention decoder rescoring``.
+ Otherwise, you can only run ``HLG decoding`` with Colab.
+
+**Congratulations!** You have finished the librispeech ASR recipe with
+conformer CTC models in ``icefall``.
diff --git a/docs/source/recipes/librispeech/images/librispeech-conformer-ctc-tensorboard-log.png b/docs/source/recipes/librispeech/images/librispeech-conformer-ctc-tensorboard-log.png
new file mode 100644
index 000000000..4e8c2ea7c
Binary files /dev/null and b/docs/source/recipes/librispeech/images/librispeech-conformer-ctc-tensorboard-log.png differ
diff --git a/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst b/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst
new file mode 100644
index 000000000..848026802
--- /dev/null
+++ b/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst
@@ -0,0 +1,394 @@
+TDNN-LSTM-CTC
+=============
+
+This tutorial shows you how to run a TDNN-LSTM-CTC model with the `LibriSpeech `_ dataset.
+
+
+.. HINT::
+
+ We assume you have read the page :ref:`install icefall` and have setup
+ the environment for ``icefall``.
+
+
+Data preparation
+----------------
+
+.. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ ./prepare.sh
+
+The script ``./prepare.sh`` handles the data preparation for you, **automagically**.
+All you need to do is to run it.
+
+The data preparation contains several stages, you can use the following two
+options:
+
+ - ``--stage``
+ - ``--stop-stage``
+
+to control which stage(s) should be run. By default, all stages are executed.
+
+
+For example,
+
+.. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ ./prepare.sh --stage 0 --stop-stage 0
+
+means to run only stage 0.
+
+To run stage 2 to stage 5, use:
+
+.. code-block:: bash
+
+ $ ./prepare.sh --stage 2 --stop-stage 5
+
+
+Training
+--------
+
+Now describing the training of TDNN-LSTM-CTC model, contained in
+the `tdnn_lstm_ctc `_
+folder.
+
+The command to run the training part is:
+
+.. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ export CUDA_VISIBLE_DEVICES="0,1,2,3"
+ $ ./tdnn_lstm_ctc/train.py --world-size 4
+
+By default, it will run ``20`` epochs. Training logs and checkpoints are saved
+in ``tdnn_lstm_ctc/exp``.
+
+In ``tdnn_lstm_ctc/exp``, you will find the following files:
+
+ - ``epoch-0.pt``, ``epoch-1.pt``, ..., ``epoch-19.pt``
+
+ These are checkpoint files, containing model ``state_dict`` and optimizer ``state_dict``.
+ To resume training from some checkpoint, say ``epoch-10.pt``, you can use:
+
+ .. code-block:: bash
+
+ $ ./tdnn_lstm_ctc/train.py --start-epoch 11
+
+ - ``tensorboard/``
+
+ This folder contains TensorBoard logs. Training loss, validation loss, learning
+ rate, etc, are recorded in these logs. You can visualize them by:
+
+ .. code-block:: bash
+
+ $ cd tdnn_lstm_ctc/exp/tensorboard
+ $ tensorboard dev upload --logdir . --description "TDNN LSTM training for librispeech with icefall"
+
+ - ``log/log-train-xxxx``
+
+ It is the detailed training log in text format, same as the one
+ you saw printed to the console during training.
+
+
+To see available training options, you can use:
+
+.. code-block:: bash
+
+ $ ./tdnn_lstm_ctc/train.py --help
+
+Other training options, e.g., learning rate, results dir, etc., are
+pre-configured in the function ``get_params()``
+in `tdnn_lstm_ctc/train.py `_.
+Normally, you don't need to change them. You can change them by modifying the code, if
+you want.
+
+Decoding
+--------
+
+The decoding part uses checkpoints saved by the training part, so you have
+to run the training part first.
+
+The command for decoding is:
+
+.. code-block:: bash
+
+ $ export CUDA_VISIBLE_DEVICES="0"
+ $ ./tdnn_lstm_ctc/decode.py
+
+You will see the WER in the output log.
+
+Decoded results are saved in ``tdnn_lstm_ctc/exp``.
+
+.. code-block:: bash
+
+ $ ./tdnn_lstm_ctc/decode.py --help
+
+shows you the available decoding options.
+
+Some commonly used options are:
+
+ - ``--epoch``
+
+ You can select which checkpoint to be used for decoding.
+ For instance, ``./tdnn_lstm_ctc/decode.py --epoch 10`` means to use
+ ``./tdnn_lstm_ctc/exp/epoch-10.pt`` for decoding.
+
+ - ``--avg``
+
+ It's related to model averaging. It specifies number of checkpoints
+ to be averaged. The averaged model is used for decoding.
+ For example, the following command:
+
+ .. code-block:: bash
+
+ $ ./tdnn_lstm_ctc/decode.py --epoch 10 --avg 3
+
+ uses the average of ``epoch-8.pt``, ``epoch-9.pt`` and ``epoch-10.pt``
+ for decoding.
+
+ - ``--export``
+
+ If it is ``True``, i.e., ``./tdnn_lstm_ctc/decode.py --export 1``, the code
+ will save the averaged model to ``tdnn_lstm_ctc/exp/pretrained.pt``.
+ See :ref:`tdnn_lstm_ctc use a pre-trained model` for how to use it.
+
+
+.. _tdnn_lstm_ctc use a pre-trained model:
+
+Pre-trained Model
+-----------------
+
+We have uploaded the pre-trained model to
+``_.
+
+The following shows you how to use the pre-trained model.
+
+
+Install kaldifeat
+~~~~~~~~~~~~~~~~~
+
+`kaldifeat `_ is used to
+extract features for a single sound file or multiple sound files
+at the same time.
+
+Please refer to ``_ for installation.
+
+Download the pre-trained model
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ mkdir tmp
+ $ cd tmp
+ $ git lfs install
+ $ git clone https://huggingface.co/pkufool/icefall_asr_librispeech_tdnn-lstm_ctc
+
+.. CAUTION::
+
+ You have to use ``git lfs`` to download the pre-trained model.
+
+.. CAUTION::
+
+ In order to use this pre-trained model, your k2 version has to be v1.7 or later.
+
+After downloading, you will have the following files:
+
+.. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ tree tmp
+
+.. code-block:: bash
+
+ tmp/
+ `-- icefall_asr_librispeech_tdnn-lstm_ctc
+ |-- README.md
+ |-- data
+ | |-- lang_phone
+ | | |-- HLG.pt
+ | | |-- tokens.txt
+ | | `-- words.txt
+ | `-- lm
+ | `-- G_4_gram.pt
+ |-- exp
+ | `-- pretrained.pt
+ `-- test_wavs
+ |-- 1089-134686-0001.flac
+ |-- 1221-135766-0001.flac
+ |-- 1221-135766-0002.flac
+ `-- trans.txt
+
+ 6 directories, 10 files
+
+**File descriptions**:
+
+ - ``data/lang_phone/HLG.pt``
+
+ It is the decoding graph.
+
+ - ``data/lang_phone/tokens.txt``
+
+ It contains tokens and their IDs.
+
+ - ``data/lang_phone/words.txt``
+
+ It contains words and their IDs.
+
+ - ``data/lm/G_4_gram.pt``
+
+ It is a 4-gram LM, useful for LM rescoring.
+
+ - ``exp/pretrained.pt``
+
+ It contains pre-trained model parameters, obtained by averaging
+ checkpoints from ``epoch-14.pt`` to ``epoch-19.pt``.
+ Note: We have removed optimizer ``state_dict`` to reduce file size.
+
+ - ``test_waves/*.flac``
+
+ It contains some test sound files from LibriSpeech ``test-clean`` dataset.
+
+ - ``test_waves/trans.txt``
+
+ It contains the reference transcripts for the sound files in ``test_waves/``.
+
+The information of the test sound files is listed below:
+
+.. code-block:: bash
+
+ $ soxi tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/*.flac
+
+ Input File : 'tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac'
+ Channels : 1
+ Sample Rate : 16000
+ Precision : 16-bit
+ Duration : 00:00:06.62 = 106000 samples ~ 496.875 CDDA sectors
+ File Size : 116k
+ Bit Rate : 140k
+ Sample Encoding: 16-bit FLAC
+
+
+ Input File : 'tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac'
+ Channels : 1
+ Sample Rate : 16000
+ Precision : 16-bit
+ Duration : 00:00:16.71 = 267440 samples ~ 1253.62 CDDA sectors
+ File Size : 343k
+ Bit Rate : 164k
+ Sample Encoding: 16-bit FLAC
+
+
+ Input File : 'tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac'
+ Channels : 1
+ Sample Rate : 16000
+ Precision : 16-bit
+ Duration : 00:00:04.83 = 77200 samples ~ 361.875 CDDA sectors
+ File Size : 105k
+ Bit Rate : 174k
+ Sample Encoding: 16-bit FLAC
+
+ Total Duration of 3 files: 00:00:28.16
+
+
+Inference with a pre-trained model
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ ./tdnn_lstm_ctc/pretrained.py --help
+
+shows the usage information of ``./tdnn_lstm_ctc/pretrained.py``.
+
+To decode with ``1best`` method, we can use:
+
+.. code-block:: bash
+
+ ./tdnn_lstm_ctc/pretrained.py \
+ --checkpoint ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/exp/pretraind.pt \
+ --words-file ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/words.txt \
+ --HLG ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/HLG.pt \
+ ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac \
+ ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac \
+ ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac
+
+The output is:
+
+.. code-block::
+
+ 2021-08-24 16:57:13,315 INFO [pretrained.py:168] device: cuda:0
+ 2021-08-24 16:57:13,315 INFO [pretrained.py:170] Creating model
+ 2021-08-24 16:57:18,331 INFO [pretrained.py:182] Loading HLG from ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/HLG.pt
+ 2021-08-24 16:57:27,581 INFO [pretrained.py:199] Constructing Fbank computer
+ 2021-08-24 16:57:27,584 INFO [pretrained.py:209] Reading sound files: ['./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac']
+ 2021-08-24 16:57:27,599 INFO [pretrained.py:215] Decoding started
+ 2021-08-24 16:57:27,791 INFO [pretrained.py:245] Use HLG decoding
+ 2021-08-24 16:57:28,098 INFO [pretrained.py:266]
+ ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac:
+ AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
+
+ ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac:
+ GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONORED BOSOM TO CONNECT HER PARENT FOREVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN
+
+ ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac:
+ YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
+
+
+ 2021-08-24 16:57:28,099 INFO [pretrained.py:268] Decoding Done
+
+
+To decode with ``whole-lattice-rescoring`` methond, you can use
+
+.. code-block:: bash
+
+ ./tdnn_lstm_ctc/pretrained.py \
+ --checkpoint ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/exp/pretraind.pt \
+ --words-file ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/words.txt \
+ --HLG ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/HLG.pt \
+ --method whole-lattice-rescoring \
+ --G ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lm/G_4_gram.pt \
+ --ngram-lm-scale 0.8 \
+ ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac \
+ ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac \
+ ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac
+
+The decoding output is:
+
+.. code-block::
+
+ 2021-08-24 16:39:24,725 INFO [pretrained.py:168] device: cuda:0
+ 2021-08-24 16:39:24,725 INFO [pretrained.py:170] Creating model
+ 2021-08-24 16:39:29,403 INFO [pretrained.py:182] Loading HLG from ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/HLG.pt
+ 2021-08-24 16:39:40,631 INFO [pretrained.py:190] Loading G from ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lm/G_4_gram.pt
+ 2021-08-24 16:39:53,098 INFO [pretrained.py:199] Constructing Fbank computer
+ 2021-08-24 16:39:53,107 INFO [pretrained.py:209] Reading sound files: ['./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac']
+ 2021-08-24 16:39:53,121 INFO [pretrained.py:215] Decoding started
+ 2021-08-24 16:39:53,443 INFO [pretrained.py:250] Use HLG decoding + LM rescoring
+ 2021-08-24 16:39:54,010 INFO [pretrained.py:266]
+ ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac:
+ AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
+
+ ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac:
+ GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONORED BOSOM TO CONNECT HER PARENT FOREVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN
+
+ ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac:
+ YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
+
+
+ 2021-08-24 16:39:54,010 INFO [pretrained.py:268] Decoding Done
+
+
+Colab notebook
+--------------
+
+We provide a colab notebook for decoding with pre-trained model.
+
+|librispeech tdnn_lstm_ctc colab notebook|
+
+.. |librispeech tdnn_lstm_ctc colab notebook| image:: https://colab.research.google.com/assets/colab-badge.svg
+ :target: https://colab.research.google.com/drive/1kNmDXNMwREi0rZGAOIAOJo93REBuOTcd
+
+
+**Congratulations!** You have finished the TDNN-LSTM-CTC recipe on librispeech in ``icefall``.
diff --git a/docs/source/recipes/yesno.rst b/docs/source/recipes/yesno.rst
new file mode 100644
index 000000000..cb425ad1d
--- /dev/null
+++ b/docs/source/recipes/yesno.rst
@@ -0,0 +1,445 @@
+yesno
+=====
+
+This page shows you how to run the `yesno `_ recipe. It contains:
+
+ - (1) Prepare data for training
+ - (2) Train a TDNN model
+
+ - (a) View text format logs and visualize TensorBoard logs
+ - (b) Select device type, i.e., CPU and GPU, for training
+ - (c) Change training options
+ - (d) Resume training from a checkpoint
+
+ - (3) Decode with a trained model
+
+ - (a) Select a checkpoint for decoding
+ - (b) Model averaging
+
+ - (4) Colab notebook
+
+ - (a) It shows you step by step how to setup the environment, how to do training,
+ and how to do decoding
+ - (b) How to use a pre-trained model
+
+ - (5) Inference with a pre-trained model
+
+ - (a) Download a pre-trained model, provided by us
+ - (b) Decode a single sound file with a pre-trained model
+ - (c) Decode multiple sound files at the same time
+
+It does **NOT** show you:
+
+ - (1) How to train with multiple GPUs
+
+ The ``yesno`` dataset is so small that CPU is more than enough
+ for training as well as for decoding.
+
+ - (2) How to use LM rescoring for decoding
+
+ The dataset does not have an LM for rescoring.
+
+.. HINT::
+
+ We assume you have read the page :ref:`install icefall` and have setup
+ the environment for ``icefall``.
+
+.. HINT::
+
+ You **don't** need a **GPU** to run this recipe. It can be run on a **CPU**.
+ The training part takes less than 30 **seconds** on a CPU and you will get
+ the following WER at the end::
+
+ [test_set] %WER 0.42% [1 / 240, 0 ins, 1 del, 0 sub ]
+
+Data preparation
+----------------
+
+.. code-block:: bash
+
+ $ cd egs/yesno/ASR
+ $ ./prepare.sh
+
+The script ``./prepare.sh`` handles the data preparation for you, **automagically**.
+All you need to do is to run it.
+
+The data preparation contains several stages, you can use the following two
+options:
+
+ - ``--stage``
+ - ``--stop-stage``
+
+to control which stage(s) should be run. By default, all stages are executed.
+
+
+For example,
+
+.. code-block:: bash
+
+ $ cd egs/yesno/ASR
+ $ ./prepare.sh --stage 0 --stop-stage 0
+
+means to run only stage 0.
+
+To run stage 2 to stage 5, use:
+
+.. code-block:: bash
+
+ $ ./prepare.sh --stage 2 --stop-stage 5
+
+
+Training
+--------
+
+We provide only a TDNN model, contained in
+the `tdnn `_
+folder, for ``yesno``.
+
+The command to run the training part is:
+
+.. code-block:: bash
+
+ $ cd egs/yesno/ASR
+ $ export CUDA_VISIBLE_DEVICES=""
+ $ ./tdnn/train.py
+
+By default, it will run ``15`` epochs. Training logs and checkpoints are saved
+in ``tdnn/exp``.
+
+In ``tdnn/exp``, you will find the following files:
+
+ - ``epoch-0.pt``, ``epoch-1.pt``, ...
+
+ These are checkpoint files, containing model ``state_dict`` and optimizer ``state_dict``.
+ To resume training from some checkpoint, say ``epoch-10.pt``, you can use:
+
+ .. code-block:: bash
+
+ $ ./tdnn/train.py --start-epoch 11
+
+ - ``tensorboard/``
+
+ This folder contains TensorBoard logs. Training loss, validation loss, learning
+ rate, etc, are recorded in these logs. You can visualize them by:
+
+ .. code-block:: bash
+
+ $ cd tdnn/exp/tensorboard
+ $ tensorboard dev upload --logdir . --description "TDNN training for yesno with icefall"
+
+ It will print something like below:
+
+ .. code-block::
+
+ TensorFlow installation not found - running with reduced feature set.
+ Upload started and will continue reading any new data as it's added to the logdir.
+
+ To stop uploading, press Ctrl-C.
+
+ New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/yKUbhb5wRmOSXYkId1z9eg/
+
+ [2021-08-23T23:49:41] Started scanning logdir.
+ [2021-08-23T23:49:42] Total uploaded: 135 scalars, 0 tensors, 0 binary objects
+ Listening for new data in logdir...
+
+ Note there is a URL in the above output, click it and you will see
+ the following screenshot:
+
+ .. figure:: images/yesno-tdnn-tensorboard-log.png
+ :width: 600
+ :alt: TensorBoard screenshot
+ :align: center
+ :target: https://tensorboard.dev/experiment/yKUbhb5wRmOSXYkId1z9eg/
+
+ TensorBoard screenshot.
+
+ - ``log/log-train-xxxx``
+
+ It is the detailed training log in text format, same as the one
+ you saw printed to the console during training.
+
+
+
+.. NOTE::
+
+ By default, ``./tdnn/train.py`` uses GPU 0 for training if GPUs are available.
+ If you have two GPUs, say, GPU 0 and GPU 1, and you want to use GPU 1 for
+ training, you can run:
+
+ .. code-block:: bash
+
+ $ export CUDA_VISIBLE_DEVICES="1"
+ $ ./tdnn/train.py
+
+ Since the ``yesno`` dataset is very small, containing only 30 sound files
+ for training, and the model in use is also very small, we use:
+
+ .. code-block:: bash
+
+ $ export CUDA_VISIBLE_DEVICES=""
+
+ so that ``./tdnn/train.py`` uses CPU during training.
+
+ If you don't have GPUs, then you don't need to
+ run ``export CUDA_VISIBLE_DEVICES=""``.
+
+To see available training options, you can use:
+
+.. code-block:: bash
+
+ $ ./tdnn/train.py --help
+
+Other training options, e.g., learning rate, results dir, etc., are
+pre-configured in the function ``get_params()``
+in `tdnn/train.py `_.
+Normally, you don't need to change them. You can change them by modifying the code, if
+you want.
+
+Decoding
+--------
+
+The decoding part uses checkpoints saved by the training part, so you have
+to run the training part first.
+
+The command for decoding is:
+
+.. code-block:: bash
+
+ $ export CUDA_VISIBLE_DEVICES=""
+ $ ./tdnn/decode.py
+
+You will see the WER in the output log.
+
+Decoded results are saved in ``tdnn/exp``.
+
+.. code-block:: bash
+
+ $ ./tdnn/decode.py --help
+
+shows you the available decoding options.
+
+Some commonly used options are:
+
+ - ``--epoch``
+
+ You can select which checkpoint to be used for decoding.
+ For instance, ``./tdnn/decode.py --epoch 10`` means to use
+ ``./tdnn/exp/epoch-10.pt`` for decoding.
+
+ - ``--avg``
+
+ It's related to model averaging. It specifies number of checkpoints
+ to be averaged. The averaged model is used for decoding.
+ For example, the following command:
+
+ .. code-block:: bash
+
+ $ ./tdnn/decode.py --epoch 10 --avg 3
+
+ uses the average of ``epoch-8.pt``, ``epoch-9.pt`` and ``epoch-10.pt``
+ for decoding.
+
+ - ``--export``
+
+ If it is ``True``, i.e., ``./tdnn/decode.py --export 1``, the code
+ will save the averaged model to ``tdnn/exp/pretrained.pt``.
+ See :ref:`yesno use a pre-trained model` for how to use it.
+
+
+.. _yesno use a pre-trained model:
+
+Pre-trained Model
+-----------------
+
+We have uploaded the pre-trained model to
+``_.
+
+The following shows you how to use the pre-trained model.
+
+Download the pre-trained model
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+ $ cd egs/yesno/ASR
+ $ mkdir tmp
+ $ cd tmp
+ $ git lfs install
+ $ git clone https://huggingface.co/csukuangfj/icefall_asr_yesno_tdnn
+
+.. CAUTION::
+
+ You have to use ``git lfs`` to download the pre-trained model.
+
+After downloading, you will have the following files:
+
+.. code-block:: bash
+
+ $ cd egs/yesno/ASR
+ $ tree tmp
+
+.. code-block:: bash
+
+ tmp/
+ `-- icefall_asr_yesno_tdnn
+ |-- README.md
+ |-- lang_phone
+ | |-- HLG.pt
+ | |-- L.pt
+ | |-- L_disambig.pt
+ | |-- Linv.pt
+ | |-- lexicon.txt
+ | |-- lexicon_disambig.txt
+ | |-- tokens.txt
+ | `-- words.txt
+ |-- lm
+ | |-- G.arpa
+ | `-- G.fst.txt
+ |-- pretrained.pt
+ `-- test_waves
+ |-- 0_0_0_1_0_0_0_1.wav
+ |-- 0_0_1_0_0_0_1_0.wav
+ |-- 0_0_1_0_0_1_1_1.wav
+ |-- 0_0_1_0_1_0_0_1.wav
+ |-- 0_0_1_1_0_0_0_1.wav
+ |-- 0_0_1_1_0_1_1_0.wav
+ |-- 0_0_1_1_1_0_0_0.wav
+ |-- 0_0_1_1_1_1_0_0.wav
+ |-- 0_1_0_0_0_1_0_0.wav
+ |-- 0_1_0_0_1_0_1_0.wav
+ |-- 0_1_0_1_0_0_0_0.wav
+ |-- 0_1_0_1_1_1_0_0.wav
+ |-- 0_1_1_0_0_1_1_1.wav
+ |-- 0_1_1_1_0_0_1_0.wav
+ |-- 0_1_1_1_1_0_1_0.wav
+ |-- 1_0_0_0_0_0_0_0.wav
+ |-- 1_0_0_0_0_0_1_1.wav
+ |-- 1_0_0_1_0_1_1_1.wav
+ |-- 1_0_1_1_0_1_1_1.wav
+ |-- 1_0_1_1_1_1_0_1.wav
+ |-- 1_1_0_0_0_1_1_1.wav
+ |-- 1_1_0_0_1_0_1_1.wav
+ |-- 1_1_0_1_0_1_0_0.wav
+ |-- 1_1_0_1_1_0_0_1.wav
+ |-- 1_1_0_1_1_1_1_0.wav
+ |-- 1_1_1_0_0_1_0_1.wav
+ |-- 1_1_1_0_1_0_1_0.wav
+ |-- 1_1_1_1_0_0_1_0.wav
+ |-- 1_1_1_1_1_0_0_0.wav
+ `-- 1_1_1_1_1_1_1_1.wav
+
+ 4 directories, 42 files
+
+.. code-block:: bash
+
+ $ soxi tmp/icefall_asr_yesno_tdnn/test_waves/0_0_1_0_1_0_0_1.wav
+
+ Input File : 'tmp/icefall_asr_yesno_tdnn/test_waves/0_0_1_0_1_0_0_1.wav'
+ Channels : 1
+ Sample Rate : 8000
+ Precision : 16-bit
+ Duration : 00:00:06.76 = 54080 samples ~ 507 CDDA sectors
+ File Size : 108k
+ Bit Rate : 128k
+ Sample Encoding: 16-bit Signed Integer PCM
+
+- ``0_0_1_0_1_0_0_1.wav``
+
+ 0 means No; 1 means Yes. No and Yes are not in English,
+ but in `Hebrew `_.
+ So this file contains ``NO NO YES NO YES NO NO YES``.
+
+Download kaldifeat
+~~~~~~~~~~~~~~~~~~
+
+`kaldifeat `_ is used for extracting
+features from a single or multiple sound files. Please refer to
+``_ to install ``kaldifeat`` first.
+
+Inference with a pre-trained model
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+ $ cd egs/yesno/ASR
+ $ ./tdnn/pretrained.py --help
+
+shows the usage information of ``./tdnn/pretrained.py``.
+
+To decode a single file, we can use:
+
+.. code-block:: bash
+
+ ./tdnn/pretrained.py \
+ --checkpoint ./tmp/icefall_asr_yesno_tdnn/pretrained.pt \
+ --words-file ./tmp/icefall_asr_yesno_tdnn/lang_phone/words.txt \
+ --HLG ./tmp/icefall_asr_yesno_tdnn/lang_phone/HLG.pt \
+ ./tmp/icefall_asr_yesno_tdnn/test_waves/0_0_1_0_1_0_0_1.wav
+
+The output is:
+
+.. code-block::
+
+ 2021-08-24 12:22:51,621 INFO [pretrained.py:119] {'feature_dim': 23, 'num_classes': 4, 'sample_rate': 8000, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'checkpoint': './tmp/icefall_asr_yesno_tdnn/pretrained.pt', 'words_file': './tmp/icefall_asr_yesno_tdnn/lang_phone/words.txt', 'HLG': './tmp/icefall_asr_yesno_tdnn/lang_phone/HLG.pt', 'sound_files': ['./tmp/icefall_asr_yesno_tdnn/test_waves/0_0_1_0_1_0_0_1.wav']}
+ 2021-08-24 12:22:51,645 INFO [pretrained.py:125] device: cpu
+ 2021-08-24 12:22:51,645 INFO [pretrained.py:127] Creating model
+ 2021-08-24 12:22:51,650 INFO [pretrained.py:139] Loading HLG from ./tmp/icefall_asr_yesno_tdnn/lang_phone/HLG.pt
+ 2021-08-24 12:22:51,651 INFO [pretrained.py:143] Constructing Fbank computer
+ 2021-08-24 12:22:51,652 INFO [pretrained.py:153] Reading sound files: ['./tmp/icefall_asr_yesno_tdnn/test_waves/0_0_1_0_1_0_0_1.wav']
+ 2021-08-24 12:22:51,684 INFO [pretrained.py:159] Decoding started
+ 2021-08-24 12:22:51,708 INFO [pretrained.py:198]
+ ./tmp/icefall_asr_yesno_tdnn/test_waves/0_0_1_0_1_0_0_1.wav:
+ NO NO YES NO YES NO NO YES
+
+
+ 2021-08-24 12:22:51,708 INFO [pretrained.py:200] Decoding Done
+
+You can see that for the sound file ``0_0_1_0_1_0_0_1.wav``, the decoding result is
+``NO NO YES NO YES NO NO YES``.
+
+To decode **multiple** files at the same time, you can use
+
+.. code-block:: bash
+
+ ./tdnn/pretrained.py \
+ --checkpoint ./tmp/icefall_asr_yesno_tdnn/pretrained.pt \
+ --words-file ./tmp/icefall_asr_yesno_tdnn/lang_phone/words.txt \
+ --HLG ./tmp/icefall_asr_yesno_tdnn/lang_phone/HLG.pt \
+ ./tmp/icefall_asr_yesno_tdnn/test_waves/0_0_1_0_1_0_0_1.wav \
+ ./tmp/icefall_asr_yesno_tdnn/test_waves/1_0_1_1_0_1_1_1.wav
+
+The decoding output is:
+
+.. code-block::
+
+ 2021-08-24 12:25:20,159 INFO [pretrained.py:119] {'feature_dim': 23, 'num_classes': 4, 'sample_rate': 8000, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'checkpoint': './tmp/icefall_asr_yesno_tdnn/pretrained.pt', 'words_file': './tmp/icefall_asr_yesno_tdnn/lang_phone/words.txt', 'HLG': './tmp/icefall_asr_yesno_tdnn/lang_phone/HLG.pt', 'sound_files': ['./tmp/icefall_asr_yesno_tdnn/test_waves/0_0_1_0_1_0_0_1.wav', './tmp/icefall_asr_yesno_tdnn/test_waves/1_0_1_1_0_1_1_1.wav']}
+ 2021-08-24 12:25:20,181 INFO [pretrained.py:125] device: cpu
+ 2021-08-24 12:25:20,181 INFO [pretrained.py:127] Creating model
+ 2021-08-24 12:25:20,185 INFO [pretrained.py:139] Loading HLG from ./tmp/icefall_asr_yesno_tdnn/lang_phone/HLG.pt
+ 2021-08-24 12:25:20,186 INFO [pretrained.py:143] Constructing Fbank computer
+ 2021-08-24 12:25:20,187 INFO [pretrained.py:153] Reading sound files: ['./tmp/icefall_asr_yesno_tdnn/test_waves/0_0_1_0_1_0_0_1.wav',
+ './tmp/icefall_asr_yesno_tdnn/test_waves/1_0_1_1_0_1_1_1.wav']
+ 2021-08-24 12:25:20,213 INFO [pretrained.py:159] Decoding started
+ 2021-08-24 12:25:20,287 INFO [pretrained.py:198]
+ ./tmp/icefall_asr_yesno_tdnn/test_waves/0_0_1_0_1_0_0_1.wav:
+ NO NO YES NO YES NO NO YES
+
+ ./tmp/icefall_asr_yesno_tdnn/test_waves/1_0_1_1_0_1_1_1.wav:
+ YES NO YES YES NO YES YES YES
+
+ 2021-08-24 12:25:20,287 INFO [pretrained.py:200] Decoding Done
+
+You can see again that it decodes correctly.
+
+Colab notebook
+--------------
+
+We do provide a colab notebook for this recipe.
+
+|yesno colab notebook|
+
+.. |yesno colab notebook| image:: https://colab.research.google.com/assets/colab-badge.svg
+ :target: https://colab.research.google.com/drive/1tIjjzaJc3IvGyKiMCDWO-TSnBgkcuN3B?usp=sharing
+
+
+**Congratulations!** You have finished the simplest speech recognition recipe in ``icefall``.
diff --git a/egs/librispeech/ASR/README.md b/egs/librispeech/ASR/README.md
index 45c9ef4de..ae0c2684d 100644
--- a/egs/librispeech/ASR/README.md
+++ b/egs/librispeech/ASR/README.md
@@ -1,121 +1,3 @@
-Run `./prepare.sh` to prepare the data.
-
-Run `./xxx_train.py` (to be added) to train a model.
-
-## Conformer-CTC
-Results of the pre-trained model from
-``
-are given below
-
-### HLG - no LM rescoring
-
-(output beam size is 8)
-
-#### 1-best decoding
-
-```
-[test-clean-no_rescore] %WER 3.15% [1656 / 52576, 127 ins, 377 del, 1152 sub ]
-[test-other-no_rescore] %WER 7.03% [3682 / 52343, 220 ins, 1024 del, 2438 sub ]
-```
-
-#### n-best decoding
-
-For n=100,
-
-```
-[test-clean-no_rescore-100] %WER 3.15% [1656 / 52576, 127 ins, 377 del, 1152 sub ]
-[test-other-no_rescore-100] %WER 7.14% [3737 / 52343, 275 ins, 1020 del, 2442 sub ]
-```
-
-For n=200,
-
-```
-[test-clean-no_rescore-200] %WER 3.16% [1660 / 52576, 125 ins, 378 del, 1157 sub ]
-[test-other-no_rescore-200] %WER 7.04% [3684 / 52343, 228 ins, 1012 del, 2444 sub ]
-```
-
-### HLG - with LM rescoring
-
-#### Whole lattice rescoring
-
-```
-[test-clean-lm_scale_0.8] %WER 2.77% [1456 / 52576, 150 ins, 210 del, 1096 sub ]
-[test-other-lm_scale_0.8] %WER 6.23% [3262 / 52343, 246 ins, 635 del, 2381 sub ]
-```
-
-WERs of different LM scales are:
-
-```
-For test-clean, WER of different settings are:
-lm_scale_0.8 2.77 best for test-clean
-lm_scale_0.9 2.87
-lm_scale_1.0 3.06
-lm_scale_1.1 3.34
-lm_scale_1.2 3.71
-lm_scale_1.3 4.18
-lm_scale_1.4 4.8
-lm_scale_1.5 5.48
-lm_scale_1.6 6.08
-lm_scale_1.7 6.79
-lm_scale_1.8 7.49
-lm_scale_1.9 8.14
-lm_scale_2.0 8.82
-
-For test-other, WER of different settings are:
-lm_scale_0.8 6.23 best for test-other
-lm_scale_0.9 6.37
-lm_scale_1.0 6.62
-lm_scale_1.1 6.99
-lm_scale_1.2 7.46
-lm_scale_1.3 8.13
-lm_scale_1.4 8.84
-lm_scale_1.5 9.61
-lm_scale_1.6 10.32
-lm_scale_1.7 11.17
-lm_scale_1.8 12.12
-lm_scale_1.9 12.93
-lm_scale_2.0 13.77
-```
-
-#### n-best LM rescoring
-
-n = 100
-
-```
-[test-clean-lm_scale_0.8] %WER 2.79% [1469 / 52576, 149 ins, 212 del, 1108 sub ]
-[test-other-lm_scale_0.8] %WER 6.36% [3329 / 52343, 259 ins, 666 del, 2404 sub ]
-```
-
-WERs of different LM scales are:
-
-```
-For test-clean, WER of different settings are:
-lm_scale_0.8 2.79 best for test-clean
-lm_scale_0.9 2.89
-lm_scale_1.0 3.03
-lm_scale_1.1 3.28
-lm_scale_1.2 3.52
-lm_scale_1.3 3.78
-lm_scale_1.4 4.04
-lm_scale_1.5 4.24
-lm_scale_1.6 4.45
-lm_scale_1.7 4.58
-lm_scale_1.8 4.7
-lm_scale_1.9 4.8
-lm_scale_2.0 4.92
-For test-other, WER of different settings are:
-lm_scale_0.8 6.36 best for test-other
-lm_scale_0.9 6.45
-lm_scale_1.0 6.64
-lm_scale_1.1 6.92
-lm_scale_1.2 7.25
-lm_scale_1.3 7.59
-lm_scale_1.4 7.88
-lm_scale_1.5 8.13
-lm_scale_1.6 8.36
-lm_scale_1.7 8.54
-lm_scale_1.8 8.71
-lm_scale_1.9 8.88
-lm_scale_2.0 9.02
-```
+Please refer to
+for how to run models in this recipe.
diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md
new file mode 100644
index 000000000..d04e912bf
--- /dev/null
+++ b/egs/librispeech/ASR/RESULTS.md
@@ -0,0 +1,71 @@
+## Results
+
+### LibriSpeech BPE training results (Conformer-CTC)
+#### 2021-08-19
+(Wei Kang): Result of https://github.com/k2-fsa/icefall/pull/13
+
+TensorBoard log is available at https://tensorboard.dev/experiment/GnRzq8WWQW62dK4bklXBTg/#scalars
+
+Pretrained model is available at https://huggingface.co/pkufool/icefall_asr_librispeech_conformer_ctc
+
+The best decoding results (WER) are listed below, we got this results by averaging models from epoch 15 to 34, and using `attention-decoder` decoder with num_paths equals to 100.
+
+||test-clean|test-other|
+|--|--|--|
+|WER| 2.57% | 5.94% |
+
+To get more unique paths, we scaled the lattice.scores with 0.5 (see https://github.com/k2-fsa/icefall/pull/10#discussion_r690951662 for more details), we searched the lm_score_scale and attention_score_scale for best results, the scales that produced the WER above are also listed below.
+
+||lm_scale|attention_scale|
+|--|--|--|
+|test-clean|1.3|1.2|
+|test-other|1.2|1.1|
+
+You can use the following commands to reproduce our results:
+
+```bash
+git clone https://github.com/k2-fsa/icefall
+cd icefall
+
+# It was using ef233486, you may not need to switch to it
+# git checkout ef233486
+
+cd egs/librispeech/ASR
+./prepare.sh
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+python conformer_ctc/train.py --bucketing-sampler True \
+ --concatenate-cuts False \
+ --max-duration 200 \
+ --full-libri True \
+ --world-size 4
+
+python conformer_ctc/decode.py --lattice-score-scale 0.5 \
+ --epoch 34 \
+ --avg 20 \
+ --method attention-decoder \
+ --max-duration 20 \
+ --num-paths 100
+```
+
+### LibriSpeech training results (Tdnn-Lstm)
+#### 2021-08-24
+
+(Wei Kang): Result of phone based Tdnn-Lstm model.
+
+Icefall version: https://github.com/k2-fsa/icefall/commit/caa0b9e9425af27e0c6211048acb55a76ed5d315
+
+Pretrained model is available at https://huggingface.co/pkufool/icefall_asr_librispeech_tdnn-lstm_ctc
+
+The best decoding results (WER) are listed below, we got this results by averaging models from epoch 19 to 14, and using `whole-lattice-rescoring` decoding method.
+
+||test-clean|test-other|
+|--|--|--|
+|WER| 6.59% | 17.69% |
+
+We searched the lm_score_scale for best results, the scales that produced the WER above are also listed below.
+
+||lm_scale|
+|--|--|
+|test-clean|0.8|
+|test-other|0.9|
diff --git a/egs/librispeech/ASR/conformer_ctc/README.md b/egs/librispeech/ASR/conformer_ctc/README.md
new file mode 100644
index 000000000..23b51167b
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc/README.md
@@ -0,0 +1,3 @@
+Please visit
+
+for how to run this recipe.
diff --git a/egs/librispeech/ASR/conformer_ctc/asr_datamodule.py b/egs/librispeech/ASR/conformer_ctc/asr_datamodule.py
new file mode 120000
index 000000000..fa1b8cca3
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc/asr_datamodule.py
@@ -0,0 +1 @@
+../tdnn_lstm_ctc/asr_datamodule.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/conformer_ctc/conformer.py b/egs/librispeech/ASR/conformer_ctc/conformer.py
index a00664a99..b19b94db1 100644
--- a/egs/librispeech/ASR/conformer_ctc/conformer.py
+++ b/egs/librispeech/ASR/conformer_ctc/conformer.py
@@ -1,7 +1,20 @@
#!/usr/bin/env python3
-
# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
-# Apache 2.0
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import math
import warnings
@@ -43,8 +56,6 @@ class Conformer(Transformer):
cnn_module_kernel: int = 31,
normalize_before: bool = True,
vgg_frontend: bool = False,
- is_espnet_structure: bool = False,
- mmi_loss: bool = True,
use_feat_batchnorm: bool = False,
) -> None:
super(Conformer, self).__init__(
@@ -59,7 +70,6 @@ class Conformer(Transformer):
dropout=dropout,
normalize_before=normalize_before,
vgg_frontend=vgg_frontend,
- mmi_loss=mmi_loss,
use_feat_batchnorm=use_feat_batchnorm,
)
@@ -72,12 +82,10 @@ class Conformer(Transformer):
dropout,
cnn_module_kernel,
normalize_before,
- is_espnet_structure,
)
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
self.normalize_before = normalize_before
- self.is_espnet_structure = is_espnet_structure
- if self.normalize_before and self.is_espnet_structure:
+ if self.normalize_before:
self.after_norm = nn.LayerNorm(d_model)
else:
# Note: TorchScript detects that self.after_norm could be used inside forward()
@@ -90,7 +98,7 @@ class Conformer(Transformer):
"""
Args:
x:
- The model input. Its shape is [N, T, C].
+ The model input. Its shape is (N, T, C).
supervisions:
Supervision in lhotse format.
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
@@ -112,7 +120,7 @@ class Conformer(Transformer):
mask = mask.to(x.device)
x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F)
- if self.normalize_before and self.is_espnet_structure:
+ if self.normalize_before:
x = self.after_norm(x)
return x, mask
@@ -146,11 +154,10 @@ class ConformerEncoderLayer(nn.Module):
dropout: float = 0.1,
cnn_module_kernel: int = 31,
normalize_before: bool = True,
- is_espnet_structure: bool = False,
) -> None:
super(ConformerEncoderLayer, self).__init__()
self.self_attn = RelPositionMultiheadAttention(
- d_model, nhead, dropout=0.0, is_espnet_structure=is_espnet_structure
+ d_model, nhead, dropout=0.0
)
self.feed_forward = nn.Sequential(
@@ -396,7 +403,7 @@ class RelPositionalEncoding(torch.nn.Module):
:,
self.pe.size(1) // 2
- x.size(1)
- + 1 : self.pe.size(1) // 2
+ + 1 : self.pe.size(1) // 2 # noqa E203
+ x.size(1),
]
return self.dropout(x), self.dropout(pos_emb)
@@ -423,7 +430,6 @@ class RelPositionMultiheadAttention(nn.Module):
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
- is_espnet_structure: bool = False,
) -> None:
super(RelPositionMultiheadAttention, self).__init__()
self.embed_dim = embed_dim
@@ -446,8 +452,6 @@ class RelPositionMultiheadAttention(nn.Module):
self._reset_parameters()
- self.is_espnet_structure = is_espnet_structure
-
def _reset_parameters(self) -> None:
nn.init.xavier_uniform_(self.in_proj.weight)
nn.init.constant_(self.in_proj.bias, 0.0)
@@ -677,9 +681,6 @@ class RelPositionMultiheadAttention(nn.Module):
_b = _b[_start:]
v = nn.functional.linear(value, _w, _b)
- if not self.is_espnet_structure:
- q = q * scaling
-
if attn_mask is not None:
assert (
attn_mask.dtype == torch.float32
@@ -772,14 +773,9 @@ class RelPositionMultiheadAttention(nn.Module):
) # (batch, head, time1, 2*time1-1)
matrix_bd = self.rel_shift(matrix_bd)
- if not self.is_espnet_structure:
- attn_output_weights = (
- matrix_ac + matrix_bd
- ) # (batch, head, time1, time2)
- else:
- attn_output_weights = (
- matrix_ac + matrix_bd
- ) * scaling # (batch, head, time1, time2)
+ attn_output_weights = (
+ matrix_ac + matrix_bd
+ ) * scaling # (batch, head, time1, time2)
attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, -1
diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py
index 889a0a474..b5b41c82e 100755
--- a/egs/librispeech/ASR/conformer_ctc/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc/decode.py
@@ -1,8 +1,20 @@
#!/usr/bin/env python3
-
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
-# (still working in progress)
import argparse
import logging
@@ -13,14 +25,15 @@ from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint
-from icefall.dataset.librispeech import LibriSpeechAsrDataModule
from icefall.decode import (
get_lattice,
nbest_decoding,
+ nbest_oracle,
one_best_decoding,
rescore_with_attention_decoder,
rescore_with_n_best_list,
@@ -32,6 +45,7 @@ from icefall.utils import (
get_texts,
setup_logger,
store_transcripts,
+ str2bool,
write_error_stats,
)
@@ -44,18 +58,76 @@ def get_parser():
parser.add_argument(
"--epoch",
type=int,
- default=9,
+ default=34,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
- default=1,
+ default=20,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
+
+ parser.add_argument(
+ "--method",
+ type=str,
+ default="attention-decoder",
+ help="""Decoding method.
+ Supported values are:
+ - (1) 1best. Extract the best path from the decoding lattice as the
+ decoding result.
+ - (2) nbest. Extract n paths from the decoding lattice; the path
+ with the highest score is the decoding result.
+ - (3) nbest-rescoring. Extract n paths from the decoding lattice,
+ rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
+ the highest score is the decoding result.
+ - (4) whole-lattice-rescoring. Rescore the decoding lattice with an
+ n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
+ is the decoding result.
+ - (5) attention-decoder. Extract n paths from the LM rescored
+ lattice, the path with the highest score is the decoding result.
+ - (6) nbest-oracle. Its WER is the lower bound of any n-best
+ rescoring method can achieve. Useful for debugging n-best
+ rescoring method.
+ """,
+ )
+
+ parser.add_argument(
+ "--num-paths",
+ type=int,
+ default=100,
+ help="""Number of paths for n-best based decoding method.
+ Used only when "method" is one of the following values:
+ nbest, nbest-rescoring, attention-decoder, and nbest-oracle
+ """,
+ )
+
+ parser.add_argument(
+ "--lattice-score-scale",
+ type=float,
+ default=0.5,
+ help="""The scale to be applied to `lattice.scores`.
+ It's needed if you use any kinds of n-best based rescoring.
+ Used only when "method" is one of the following values:
+ nbest, nbest-rescoring, attention-decoder, and nbest-oracle
+ A smaller value results in more unique paths.
+ """,
+ )
+
+ parser.add_argument(
+ "--export",
+ type=str2bool,
+ default=False,
+ help="""When enabled, the averaged model is saved to
+ conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved.
+ pretrained.pt contains a dict {"model": model.state_dict()},
+ which can be loaded by `icefall.checkpoint.load_checkpoint()`.
+ """,
+ )
+
return parser
@@ -65,31 +137,20 @@ def get_params() -> AttributeDict:
"exp_dir": Path("conformer_ctc/exp"),
"lang_dir": Path("data/lang_bpe"),
"lm_dir": Path("data/lm"),
+ # parameters for conformer
+ "subsampling_factor": 4,
+ "vgg_frontend": False,
+ "use_feat_batchnorm": True,
"feature_dim": 80,
"nhead": 8,
"attention_dim": 512,
- "subsampling_factor": 4,
"num_decoder_layers": 6,
- "vgg_frontend": False,
- "is_espnet_structure": True,
- "mmi_loss": False,
- "use_feat_batchnorm": True,
+ # parameters for decoding
"search_beam": 20,
"output_beam": 8,
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
- # Possible values for method:
- # - 1best
- # - nbest
- # - nbest-rescoring
- # - whole-lattice-rescoring
- # - attention-decoder
- # "method": "whole-lattice-rescoring",
- "method": "attention-decoder",
- # num_paths is used when method is "nbest", "nbest-rescoring",
- # and attention-decoder
- "num_paths": 100,
}
)
return params
@@ -100,7 +161,7 @@ def decode_one_batch(
model: nn.Module,
HLG: k2.Fsa,
batch: dict,
- lexicon: Lexicon,
+ word_table: k2.SymbolTable,
sos_id: int,
eos_id: int,
G: Optional[k2.Fsa] = None,
@@ -134,8 +195,8 @@ def decode_one_batch(
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
- lexicon:
- It contains word symbol table.
+ word_table:
+ The word symbol table.
sos_id:
The token ID of the SOS.
eos_id:
@@ -152,12 +213,12 @@ def decode_one_batch(
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
- # at entry, feature is [N, T, C]
+ # at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
- # nnet_output is [N, T, C]
+ # nnet_output is (N, T, C)
supervision_segments = torch.stack(
(
@@ -179,6 +240,24 @@ def decode_one_batch(
subsampling_factor=params.subsampling_factor,
)
+ if params.method == "nbest-oracle":
+ # Note: You can also pass rescored lattices to it.
+ # We choose the HLG decoded lattice for speed reasons
+ # as HLG decoding is faster and the oracle WER
+ # is only slightly worse than that of rescored lattices.
+ best_path = nbest_oracle(
+ lattice=lattice,
+ num_paths=params.num_paths,
+ ref_texts=supervisions["text"],
+ word_table=word_table,
+ lattice_score_scale=params.lattice_score_scale,
+ oov="",
+ )
+ hyps = get_texts(best_path)
+ hyps = [[word_table[i] for i in ids] for ids in hyps]
+ key = f"oracle_{params.num_paths}_lattice_score_scale_{params.lattice_score_scale}" # noqa
+ return {key: hyps}
+
if params.method in ["1best", "nbest"]:
if params.method == "1best":
best_path = one_best_decoding(
@@ -190,11 +269,12 @@ def decode_one_batch(
lattice=lattice,
num_paths=params.num_paths,
use_double_scores=params.use_double_scores,
+ lattice_score_scale=params.lattice_score_scale,
)
- key = f"no_rescore-{params.num_paths}"
+ key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}" # noqa
hyps = get_texts(best_path)
- hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
+ hyps = [[word_table[i] for i in ids] for ids in hyps]
return {key: hyps}
assert params.method in [
@@ -203,7 +283,8 @@ def decode_one_batch(
"attention-decoder",
]
- lm_scale_list = [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
+ lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
+ lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
if params.method == "nbest-rescoring":
@@ -212,16 +293,23 @@ def decode_one_batch(
G=G,
num_paths=params.num_paths,
lm_scale_list=lm_scale_list,
+ lattice_score_scale=params.lattice_score_scale,
)
elif params.method == "whole-lattice-rescoring":
best_path_dict = rescore_with_whole_lattice(
- lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list
+ lattice=lattice,
+ G_with_epsilon_loops=G,
+ lm_scale_list=lm_scale_list,
)
elif params.method == "attention-decoder":
# lattice uses a 3-gram Lm. We rescore it with a 4-gram LM.
rescored_lattice = rescore_with_whole_lattice(
- lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
+ lattice=lattice,
+ G_with_epsilon_loops=G,
+ lm_scale_list=None,
)
+ # TODO: pass `lattice` instead of `rescored_lattice` to
+ # `rescore_with_attention_decoder`
best_path_dict = rescore_with_attention_decoder(
lattice=rescored_lattice,
@@ -231,15 +319,20 @@ def decode_one_batch(
memory_key_padding_mask=memory_key_padding_mask,
sos_id=sos_id,
eos_id=eos_id,
+ lattice_score_scale=params.lattice_score_scale,
)
else:
assert False, f"Unsupported decoding method: {params.method}"
ans = dict()
- for lm_scale_str, best_path in best_path_dict.items():
- hyps = get_texts(best_path)
- hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
- ans[lm_scale_str] = hyps
+ if best_path_dict is not None:
+ for lm_scale_str, best_path in best_path_dict.items():
+ hyps = get_texts(best_path)
+ hyps = [[word_table[i] for i in ids] for ids in hyps]
+ ans[lm_scale_str] = hyps
+ else:
+ for lm_scale in lm_scale_list:
+ ans[lm_scale_str] = [[] * lattice.shape[0]]
return ans
@@ -248,7 +341,7 @@ def decode_dataset(
params: AttributeDict,
model: nn.Module,
HLG: k2.Fsa,
- lexicon: Lexicon,
+ word_table: k2.SymbolTable,
sos_id: int,
eos_id: int,
G: Optional[k2.Fsa] = None,
@@ -264,8 +357,8 @@ def decode_dataset(
The neural model.
HLG:
The decoding graph.
- lexicon:
- It contains word symbol table.
+ word_table:
+ It is the word symbol table.
sos_id:
The token ID for SOS.
eos_id:
@@ -284,7 +377,11 @@ def decode_dataset(
results = []
num_cuts = 0
- tot_num_cuts = len(dl.dataset.cuts)
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
@@ -295,7 +392,7 @@ def decode_dataset(
model=model,
HLG=HLG,
batch=batch,
- lexicon=lexicon,
+ word_table=word_table,
G=G,
sos_id=sos_id,
eos_id=eos_id,
@@ -313,10 +410,10 @@ def decode_dataset(
num_cuts += len(batch["supervisions"]["text"])
if batch_idx % 100 == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
logging.info(
- f"batch {batch_idx}, cuts processed until now is "
- f"{num_cuts}/{tot_num_cuts} "
- f"({float(num_cuts)/tot_num_cuts*100:.6f}%)"
+ f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
@@ -376,7 +473,7 @@ def main():
params = get_params()
params.update(vars(args))
- setup_logger(f"{params.exp_dir}/log/log-decode")
+ setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
logging.info("Decoding started")
logging.info(params)
@@ -399,7 +496,9 @@ def main():
sos_id = graph_compiler.sos_id
eos_id = graph_compiler.eos_id
- HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt"))
+ HLG = k2.Fsa.from_dict(
+ torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
+ )
HLG = HLG.to(device)
assert HLG.requires_grad is False
@@ -430,7 +529,7 @@ def main():
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
else:
logging.info("Loading pre-compiled G_4_gram.pt")
- d = torch.load(params.lm_dir / "G_4_gram.pt")
+ d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
G = k2.Fsa.from_dict(d).to(device)
if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
@@ -454,8 +553,6 @@ def main():
subsampling_factor=params.subsampling_factor,
num_decoder_layers=params.num_decoder_layers,
vgg_frontend=params.vgg_frontend,
- is_espnet_structure=params.is_espnet_structure,
- mmi_loss=params.mmi_loss,
use_feat_batchnorm=params.use_feat_batchnorm,
)
@@ -470,6 +567,13 @@ def main():
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames))
+ if params.export:
+ logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
+ torch.save(
+ {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
+ )
+ return
+
model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
@@ -489,7 +593,7 @@ def main():
params=params,
model=model,
HLG=HLG,
- lexicon=lexicon,
+ word_table=lexicon.word_table,
G=G,
sos_id=sos_id,
eos_id=eos_id,
diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py
new file mode 100755
index 000000000..c924b87bb
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py
@@ -0,0 +1,364 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+import math
+from typing import List
+
+import k2
+import kaldifeat
+import torch
+import torchaudio
+from conformer import Conformer
+from torch.nn.utils.rnn import pad_sequence
+
+from icefall.decode import (
+ get_lattice,
+ one_best_decoding,
+ rescore_with_attention_decoder,
+ rescore_with_whole_lattice,
+)
+from icefall.utils import AttributeDict, get_texts
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--checkpoint",
+ type=str,
+ required=True,
+ help="Path to the checkpoint. "
+ "The checkpoint is assumed to be saved by "
+ "icefall.checkpoint.save_checkpoint().",
+ )
+
+ parser.add_argument(
+ "--words-file",
+ type=str,
+ required=True,
+ help="Path to words.txt",
+ )
+
+ parser.add_argument(
+ "--HLG", type=str, required=True, help="Path to HLG.pt."
+ )
+
+ parser.add_argument(
+ "--method",
+ type=str,
+ default="1best",
+ help="""Decoding method.
+ Possible values are:
+ (1) 1best - Use the best path as decoding output. Only
+ the transformer encoder output is used for decoding.
+ We call it HLG decoding.
+ (2) whole-lattice-rescoring - Use an LM to rescore the
+ decoding lattice and then use 1best to decode the
+ rescored lattice.
+ We call it HLG decoding + n-gram LM rescoring.
+ (3) attention-decoder - Extract n paths from the rescored
+ lattice and use the transformer attention decoder for
+ rescoring.
+ We call it HLG decoding + n-gram LM rescoring + attention
+ decoder rescoring.
+ """,
+ )
+
+ parser.add_argument(
+ "--G",
+ type=str,
+ help="""An LM for rescoring.
+ Used only when method is
+ whole-lattice-rescoring or attention-decoder.
+ It's usually a 4-gram LM.
+ """,
+ )
+
+ parser.add_argument(
+ "--num-paths",
+ type=int,
+ default=100,
+ help="""
+ Used only when method is attention-decoder.
+ It specifies the size of n-best list.""",
+ )
+
+ parser.add_argument(
+ "--ngram-lm-scale",
+ type=float,
+ default=1.3,
+ help="""
+ Used only when method is whole-lattice-rescoring and attention-decoder.
+ It specifies the scale for n-gram LM scores.
+ (Note: You need to tune it on a dataset.)
+ """,
+ )
+
+ parser.add_argument(
+ "--attention-decoder-scale",
+ type=float,
+ default=1.2,
+ help="""
+ Used only when method is attention-decoder.
+ It specifies the scale for attention decoder scores.
+ (Note: You need to tune it on a dataset.)
+ """,
+ )
+
+ parser.add_argument(
+ "--lattice-score-scale",
+ type=float,
+ default=0.5,
+ help="""
+ Used only when method is attention-decoder.
+ It specifies the scale for lattice.scores when
+ extracting n-best lists. A smaller value results in
+ more unique number of paths with the risk of missing
+ the best path.
+ """,
+ )
+
+ parser.add_argument(
+ "--sos-id",
+ type=float,
+ default=1,
+ help="""
+ Used only when method is attention-decoder.
+ It specifies ID for the SOS token.
+ """,
+ )
+
+ parser.add_argument(
+ "--eos-id",
+ type=float,
+ default=1,
+ help="""
+ Used only when method is attention-decoder.
+ It specifies ID for the EOS token.
+ """,
+ )
+
+ parser.add_argument(
+ "sound_files",
+ type=str,
+ nargs="+",
+ help="The input sound file(s) to transcribe. "
+ "Supported formats are those supported by torchaudio.load(). "
+ "For example, wav and flac are supported. "
+ "The sample rate has to be 16kHz.",
+ )
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ params = AttributeDict(
+ {
+ "sample_rate": 16000,
+ # parameters for conformer
+ "subsampling_factor": 4,
+ "vgg_frontend": False,
+ "use_feat_batchnorm": True,
+ "feature_dim": 80,
+ "nhead": 8,
+ "num_classes": 5000,
+ "attention_dim": 512,
+ "num_decoder_layers": 6,
+ # parameters for decoding
+ "search_beam": 20,
+ "output_beam": 8,
+ "min_active_states": 30,
+ "max_active_states": 10000,
+ "use_double_scores": True,
+ }
+ )
+ return params
+
+
+def read_sound_files(
+ filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+ """Read a list of sound files into a list 1-D float32 torch tensors.
+ Args:
+ filenames:
+ A list of sound filenames.
+ expected_sample_rate:
+ The expected sample rate of the sound files.
+ Returns:
+ Return a list of 1-D float32 torch tensors.
+ """
+ ans = []
+ for f in filenames:
+ wave, sample_rate = torchaudio.load(f)
+ assert sample_rate == expected_sample_rate, (
+ f"expected sample rate: {expected_sample_rate}. "
+ f"Given: {sample_rate}"
+ )
+ # We use only the first channel
+ ans.append(wave[0])
+ return ans
+
+
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ params = get_params()
+ params.update(vars(args))
+ logging.info(f"{params}")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ logging.info("Creating model")
+ model = Conformer(
+ num_features=params.feature_dim,
+ nhead=params.nhead,
+ d_model=params.attention_dim,
+ num_classes=params.num_classes,
+ subsampling_factor=params.subsampling_factor,
+ num_decoder_layers=params.num_decoder_layers,
+ vgg_frontend=params.vgg_frontend,
+ use_feat_batchnorm=params.use_feat_batchnorm,
+ )
+
+ checkpoint = torch.load(args.checkpoint, map_location="cpu")
+ model.load_state_dict(checkpoint["model"])
+ model.to(device)
+ model.eval()
+
+ logging.info(f"Loading HLG from {params.HLG}")
+ HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+ HLG = HLG.to(device)
+ if not hasattr(HLG, "lm_scores"):
+ # For whole-lattice-rescoring and attention-decoder
+ HLG.lm_scores = HLG.scores.clone()
+
+ if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
+ logging.info(f"Loading G from {params.G}")
+ G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+ # Add epsilon self-loops to G as we will compose
+ # it with the whole lattice later
+ G = G.to(device)
+ G = k2.add_epsilon_self_loops(G)
+ G = k2.arc_sort(G)
+ G.lm_scores = G.scores.clone()
+
+ logging.info("Constructing Fbank computer")
+ opts = kaldifeat.FbankOptions()
+ opts.device = device
+ opts.frame_opts.dither = 0
+ opts.frame_opts.snip_edges = False
+ opts.frame_opts.samp_freq = params.sample_rate
+ opts.mel_opts.num_bins = params.feature_dim
+
+ fbank = kaldifeat.Fbank(opts)
+
+ logging.info(f"Reading sound files: {params.sound_files}")
+ waves = read_sound_files(
+ filenames=params.sound_files, expected_sample_rate=params.sample_rate
+ )
+ waves = [w.to(device) for w in waves]
+
+ logging.info("Decoding started")
+ features = fbank(waves)
+
+ features = pad_sequence(
+ features, batch_first=True, padding_value=math.log(1e-10)
+ )
+
+ # Note: We don't use key padding mask for attention during decoding
+ with torch.no_grad():
+ nnet_output, memory, memory_key_padding_mask = model(features)
+
+ batch_size = nnet_output.shape[0]
+ supervision_segments = torch.tensor(
+ [[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
+ dtype=torch.int32,
+ )
+
+ lattice = get_lattice(
+ nnet_output=nnet_output,
+ HLG=HLG,
+ supervision_segments=supervision_segments,
+ search_beam=params.search_beam,
+ output_beam=params.output_beam,
+ min_active_states=params.min_active_states,
+ max_active_states=params.max_active_states,
+ subsampling_factor=params.subsampling_factor,
+ )
+
+ if params.method == "1best":
+ logging.info("Use HLG decoding")
+ best_path = one_best_decoding(
+ lattice=lattice, use_double_scores=params.use_double_scores
+ )
+ elif params.method == "whole-lattice-rescoring":
+ logging.info("Use HLG decoding + LM rescoring")
+ best_path_dict = rescore_with_whole_lattice(
+ lattice=lattice,
+ G_with_epsilon_loops=G,
+ lm_scale_list=[params.ngram_lm_scale],
+ )
+ best_path = next(iter(best_path_dict.values()))
+ elif params.method == "attention-decoder":
+ logging.info("Use HLG + LM rescoring + attention decoder rescoring")
+ rescored_lattice = rescore_with_whole_lattice(
+ lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
+ )
+ best_path_dict = rescore_with_attention_decoder(
+ lattice=rescored_lattice,
+ num_paths=params.num_paths,
+ model=model,
+ memory=memory,
+ memory_key_padding_mask=memory_key_padding_mask,
+ sos_id=params.sos_id,
+ eos_id=params.eos_id,
+ lattice_score_scale=params.lattice_score_scale,
+ ngram_lm_scale=params.ngram_lm_scale,
+ attention_scale=params.attention_decoder_scale,
+ )
+ best_path = next(iter(best_path_dict.values()))
+
+ hyps = get_texts(best_path)
+ word_sym_table = k2.SymbolTable.from_file(params.words_file)
+ hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
+
+ s = "\n"
+ for filename, hyp in zip(params.sound_files, hyps):
+ words = " ".join(hyp)
+ s += f"{filename}:\n{words}\n\n"
+ logging.info(s)
+
+ logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+ formatter = (
+ "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ )
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py
index 5c3e1222e..542fb0364 100644
--- a/egs/librispeech/ASR/conformer_ctc/subsampling.py
+++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py
@@ -1,3 +1,20 @@
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
import torch
import torch.nn as nn
@@ -5,8 +22,8 @@ import torch.nn as nn
class Conv2dSubsampling(nn.Module):
"""Convolutional 2D subsampling (to 1/4 length).
- Convert an input of shape [N, T, idim] to an output
- with shape [N, T', odim], where
+ Convert an input of shape (N, T, idim) to an output
+ with shape (N, T', odim), where
T' = ((T-1)//2 - 1)//2, which approximates T' == T//4
It is based on
@@ -17,10 +34,10 @@ class Conv2dSubsampling(nn.Module):
"""
Args:
idim:
- Input dim. The input shape is [N, T, idim].
+ Input dim. The input shape is (N, T, idim).
Caution: It requires: T >=7, idim >=7
odim:
- Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim]
+ Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim)
"""
assert idim >= 7
super().__init__()
@@ -41,18 +58,18 @@ class Conv2dSubsampling(nn.Module):
Args:
x:
- Its shape is [N, T, idim].
+ Its shape is (N, T, idim).
Returns:
- Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim]
+ Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
"""
- # On entry, x is [N, T, idim]
- x = x.unsqueeze(1) # [N, T, idim] -> [N, 1, T, idim] i.e., [N, C, H, W]
+ # On entry, x is (N, T, idim)
+ x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
x = self.conv(x)
- # Now x is of shape [N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2]
+ # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
- # Now x is of shape [N, ((T-1)//2 - 1))//2, odim]
+ # Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
return x
@@ -63,8 +80,8 @@ class VggSubsampling(nn.Module):
This paper is not 100% explicit so I am guessing to some extent,
and trying to compare with other VGG implementations.
- Convert an input of shape [N, T, idim] to an output
- with shape [N, T', odim], where
+ Convert an input of shape (N, T, idim) to an output
+ with shape (N, T', odim), where
T' = ((T-1)//2 - 1)//2, which approximates T' = T//4
"""
@@ -76,10 +93,10 @@ class VggSubsampling(nn.Module):
Args:
idim:
- Input dim. The input shape is [N, T, idim].
+ Input dim. The input shape is (N, T, idim).
Caution: It requires: T >=7, idim >=7
odim:
- Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim]
+ Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim)
"""
super().__init__()
@@ -132,10 +149,10 @@ class VggSubsampling(nn.Module):
Args:
x:
- Its shape is [N, T, idim].
+ Its shape is (N, T, idim).
Returns:
- Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim]
+ Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
"""
x = x.unsqueeze(1)
x = self.layers(x)
diff --git a/egs/librispeech/ASR/conformer_ctc/test_subsampling.py b/egs/librispeech/ASR/conformer_ctc/test_subsampling.py
index 937845d77..81fa234dd 100755
--- a/egs/librispeech/ASR/conformer_ctc/test_subsampling.py
+++ b/egs/librispeech/ASR/conformer_ctc/test_subsampling.py
@@ -1,8 +1,23 @@
#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
-from subsampling import Conv2dSubsampling
-from subsampling import VggSubsampling
import torch
+from subsampling import Conv2dSubsampling, VggSubsampling
def test_conv2d_subsampling():
diff --git a/egs/librispeech/ASR/conformer_ctc/test_transformer.py b/egs/librispeech/ASR/conformer_ctc/test_transformer.py
index 08e680607..667057c51 100644
--- a/egs/librispeech/ASR/conformer_ctc/test_transformer.py
+++ b/egs/librispeech/ASR/conformer_ctc/test_transformer.py
@@ -1,17 +1,32 @@
#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import torch
+from torch.nn.utils.rnn import pad_sequence
from transformer import (
Transformer,
+ add_eos,
+ add_sos,
+ decoder_padding_mask,
encoder_padding_mask,
generate_square_subsequent_mask,
- decoder_padding_mask,
- add_sos,
- add_eos,
)
-from torch.nn.utils.rnn import pad_sequence
-
def test_encoder_padding_mask():
supervisions = {
diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py
index 552db81ec..80b2d924a 100755
--- a/egs/librispeech/ASR/conformer_ctc/train.py
+++ b/egs/librispeech/ASR/conformer_ctc/train.py
@@ -1,6 +1,21 @@
#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
-# This is just at the very beginning ...
import argparse
import logging
@@ -13,16 +28,17 @@ import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from lhotse.utils import fix_random_seed
from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from transformer import Noam
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
-from icefall.dataset.librispeech import LibriSpeechAsrDataModule
from icefall.dist import cleanup_dist, setup_dist
from icefall.lexicon import Lexicon
from icefall.utils import (
@@ -59,9 +75,23 @@ def get_parser():
help="Should various information be logged in tensorboard.",
)
- # TODO: add extra arguments and support DDP training.
- # Currently, only single GPU training is implemented. Will add
- # DDP training once single GPU training is finished.
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=35,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=0,
+ help="""Resume training from from this epoch.
+ If it is positive, it will load checkpoint from
+ conformer_ctc/exp/epoch-{start_epoch-1}.pt
+ """,
+ )
+
return parser
@@ -82,20 +112,6 @@ def get_params() -> AttributeDict:
- lang_dir: It contains language related input files such as
"lexicon.txt"
- - lr: It specifies the initial learning rate
-
- - feature_dim: The model input dim. It has to match the one used
- in computing features.
-
- - weight_decay: The weight_decay for the optimizer.
-
- - subsampling_factor: The subsampling factor for the model.
-
- - start_epoch: If it is not zero, load checkpoint `start_epoch-1`
- and continue training from that checkpoint.
-
- - num_epochs: Number of epochs to train.
-
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
@@ -114,42 +130,62 @@ def get_params() -> AttributeDict:
- log_interval: Print training loss if batch_idx % log_interval` is 0
- - valid_interval: Run validation if batch_idx % valid_interval` is 0
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+ - use_feat_batchnorm: Whether to do batch normalization for the
+ input features.
+
+ - attention_dim: Hidden dim for multi-head attention model.
+
+ - head: Number of heads of multi-head attention model.
+
+ - num_decoder_layers: Number of decoder layer of transformer decoder.
- beam_size: It is used in k2.ctc_loss
- reduction: It is used in k2.ctc_loss
- use_double_scores: It is used in k2.ctc_loss
+
+ - weight_decay: The weight_decay for the optimizer.
+
+ - lr_factor: The lr_factor for Noam optimizer.
+
+ - warm_step: The warm_step for Noam optimizer.
"""
params = AttributeDict(
{
"exp_dir": Path("conformer_ctc/exp"),
"lang_dir": Path("data/lang_bpe"),
- "feature_dim": 80,
- "weight_decay": 0.0,
- "subsampling_factor": 4,
- "start_epoch": 0,
- "num_epochs": 50,
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 10,
+ "reset_interval": 200,
"valid_interval": 3000,
- "beam_size": 10,
- "reduction": "sum",
- "use_double_scores": True,
- #
- "accum_grad": 1,
- "att_rate": 0.7,
+ # parameters for conformer
+ "feature_dim": 80,
+ "subsampling_factor": 4,
+ "use_feat_batchnorm": True,
"attention_dim": 512,
"nhead": 8,
"num_decoder_layers": 6,
- "is_espnet_structure": True,
- "mmi_loss": False,
- "use_feat_batchnorm": True,
+ # parameters for loss
+ "beam_size": 10,
+ "reduction": "sum",
+ "use_double_scores": True,
+ "att_rate": 0.7,
+ # parameters for Noam
+ "weight_decay": 1e-6,
"lr_factor": 5.0,
"warm_step": 80000,
}
@@ -274,14 +310,14 @@ def compute_loss(
"""
device = graph_compiler.device
feature = batch["inputs"]
- # at entry, feature is [N, T, C]
+ # at entry, feature is (N, T, C)
assert feature.ndim == 3
feature = feature.to(device)
supervisions = batch["supervisions"]
with torch.set_grad_enabled(is_training):
nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
- # nnet_output is [N, T, C]
+ # nnet_output is (N, T, C)
# NOTE: We need `encode_supervisions` to sort sequences with
# different duration in decreasing order, required by
@@ -440,6 +476,8 @@ def train_one_epoch(
tot_att_loss = 0.0
tot_frames = 0.0 # sum of frames over all batches
+ params.tot_loss = 0.0
+ params.tot_frames = 0.0
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
@@ -457,6 +495,7 @@ def train_one_epoch(
optimizer.zero_grad()
loss.backward()
+ clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
loss_cpu = loss.detach().cpu().item()
@@ -468,6 +507,9 @@ def train_one_epoch(
tot_ctc_loss += ctc_loss_cpu
tot_att_loss += att_loss_cpu
+ params.tot_frames += params.train_frames
+ params.tot_loss += loss_cpu
+
tot_avg_loss = tot_loss / tot_frames
tot_avg_ctc_loss = tot_ctc_loss / tot_frames
tot_avg_att_loss = tot_att_loss / tot_frames
@@ -516,6 +558,12 @@ def train_one_epoch(
tot_avg_loss,
params.batch_idx_train,
)
+ if batch_idx > 0 and batch_idx % params.reset_interval == 0:
+ tot_loss = 0.0 # sum of losses over all batches
+ tot_ctc_loss = 0.0
+ tot_att_loss = 0.0
+
+ tot_frames = 0.0 # sum of frames over all batches
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
compute_validation_loss(
@@ -551,7 +599,7 @@ def train_one_epoch(
params.batch_idx_train,
)
- params.train_loss = tot_loss / tot_frames
+ params.train_loss = params.tot_loss / params.tot_frames
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
@@ -610,8 +658,6 @@ def run(rank, world_size, args):
subsampling_factor=params.subsampling_factor,
num_decoder_layers=params.num_decoder_layers,
vgg_frontend=False,
- is_espnet_structure=params.is_espnet_structure,
- mmi_loss=params.mmi_loss,
use_feat_batchnorm=params.use_feat_batchnorm,
)
diff --git a/egs/librispeech/ASR/conformer_ctc/transformer.py b/egs/librispeech/ASR/conformer_ctc/transformer.py
index 086201267..f1d7cbbbc 100644
--- a/egs/librispeech/ASR/conformer_ctc/transformer.py
+++ b/egs/librispeech/ASR/conformer_ctc/transformer.py
@@ -1,15 +1,26 @@
-# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
-# Apache 2.0
+# Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import math
from typing import Dict, List, Optional, Tuple
-import k2
import torch
import torch.nn as nn
from subsampling import Conv2dSubsampling, VggSubsampling
-
-from icefall.utils import get_texts
from torch.nn.utils.rnn import pad_sequence
# Note: TorchScript requires Dict/List/etc. to be fully typed.
@@ -30,7 +41,6 @@ class Transformer(nn.Module):
dropout: float = 0.1,
normalize_before: bool = True,
vgg_frontend: bool = False,
- mmi_loss: bool = True,
use_feat_batchnorm: bool = False,
) -> None:
"""
@@ -59,7 +69,6 @@ class Transformer(nn.Module):
If True, use pre-layer norm; False to use post-layer norm.
vgg_frontend:
True to use vgg style frontend for subsampling.
- mmi_loss:
use_feat_batchnorm:
True to use batchnorm for the input layer.
"""
@@ -74,8 +83,8 @@ class Transformer(nn.Module):
if subsampling_factor != 4:
raise NotImplementedError("Support only 'subsampling_factor=4'.")
- # self.encoder_embed converts the input of shape [N, T, num_classes]
- # to the shape [N, T//subsampling_factor, d_model].
+ # self.encoder_embed converts the input of shape (N, T, num_classes)
+ # to the shape (N, T//subsampling_factor, d_model).
# That is, it does two things simultaneously:
# (1) subsampling: T -> T//subsampling_factor
# (2) embedding: num_classes -> d_model
@@ -108,14 +117,9 @@ class Transformer(nn.Module):
self.encoder_output_layer = nn.Linear(d_model, num_classes)
if num_decoder_layers > 0:
- if mmi_loss:
- self.decoder_num_class = (
- self.num_classes + 1
- ) # +1 for the sos/eos symbol
- else:
- self.decoder_num_class = (
- self.num_classes
- ) # bpe model already has sos/eos symbol
+ self.decoder_num_class = (
+ self.num_classes
+ ) # bpe model already has sos/eos symbol
self.decoder_embed = nn.Embedding(
num_embeddings=self.decoder_num_class, embedding_dim=d_model
@@ -155,7 +159,7 @@ class Transformer(nn.Module):
"""
Args:
x:
- The input tensor. Its shape is [N, T, C].
+ The input tensor. Its shape is (N, T, C).
supervision:
Supervision in lhotse format.
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
@@ -164,17 +168,17 @@ class Transformer(nn.Module):
Returns:
Return a tuple containing 3 tensors:
- - CTC output for ctc decoding. Its shape is [N, T, C]
- - Encoder output with shape [T, N, C]. It can be used as key and
+ - CTC output for ctc decoding. Its shape is (N, T, C)
+ - Encoder output with shape (T, N, C). It can be used as key and
value for the decoder.
- Encoder output padding mask. It can be used as
- memory_key_padding_mask for the decoder. Its shape is [N, T].
+ memory_key_padding_mask for the decoder. Its shape is (N, T).
It is None if `supervision` is None.
"""
if self.use_feat_batchnorm:
- x = x.permute(0, 2, 1) # [N, T, C] -> [N, C, T]
+ x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
x = self.feat_batchnorm(x)
- x = x.permute(0, 2, 1) # [N, C, T] -> [N, T, C]
+ x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
encoder_memory, memory_key_padding_mask = self.run_encoder(
x, supervision
)
@@ -188,7 +192,7 @@ class Transformer(nn.Module):
Args:
x:
- The model input. Its shape is [N, T, C].
+ The model input. Its shape is (N, T, C).
supervisions:
Supervision in lhotse format.
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
@@ -199,8 +203,8 @@ class Transformer(nn.Module):
padding mask for the decoder.
Returns:
Return a tuple with two tensors:
- - The encoder output, with shape [T, N, C]
- - encoder padding mask, with shape [N, T].
+ - The encoder output, with shape (T, N, C)
+ - encoder padding mask, with shape (N, T).
The mask is None if `supervisions` is None.
It is used as memory key padding mask in the decoder.
"""
@@ -218,11 +222,11 @@ class Transformer(nn.Module):
Args:
x:
The output tensor from the transformer encoder.
- Its shape is [T, N, C]
+ Its shape is (T, N, C)
Returns:
Return a tensor that can be used for CTC decoding.
- Its shape is [N, T, C]
+ Its shape is (N, T, C)
"""
x = self.encoder_output_layer(x)
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
@@ -240,7 +244,7 @@ class Transformer(nn.Module):
"""
Args:
memory:
- It's the output of the encoder with shape [T, N, C]
+ It's the output of the encoder with shape (T, N, C)
memory_key_padding_mask:
The padding mask from the encoder.
token_ids:
@@ -305,7 +309,7 @@ class Transformer(nn.Module):
"""
Args:
memory:
- It's the output of the encoder with shape [T, N, C]
+ It's the output of the encoder with shape (T, N, C)
memory_key_padding_mask:
The padding mask from the encoder.
token_ids:
@@ -338,6 +342,9 @@ class Transformer(nn.Module):
)
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
+ # TODO: Use length information to create the decoder padding mask
+ # We set the first column to False since the first column in ys_in_pad
+ # contains sos_id, which is the same as eos_id in our current setting.
tgt_key_padding_mask[:, 0] = False
tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F)
@@ -643,13 +650,13 @@ class PositionalEncoding(nn.Module):
def extend_pe(self, x: torch.Tensor) -> None:
"""Extend the time t in the positional encoding if required.
- The shape of `self.pe` is [1, T1, d_model]. The shape of the input x
- is [N, T, d_model]. If T > T1, then we change the shape of self.pe
- to [N, T, d_model]. Otherwise, nothing is done.
+ The shape of `self.pe` is (1, T1, d_model). The shape of the input x
+ is (N, T, d_model). If T > T1, then we change the shape of self.pe
+ to (N, T, d_model). Otherwise, nothing is done.
Args:
x:
- It is a tensor of shape [N, T, C].
+ It is a tensor of shape (N, T, C).
Returns:
Return None.
"""
@@ -667,7 +674,7 @@ class PositionalEncoding(nn.Module):
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
- # Now pe is of shape [1, T, d_model], where T is x.size(1)
+ # Now pe is of shape (1, T, d_model), where T is x.size(1)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -676,10 +683,10 @@ class PositionalEncoding(nn.Module):
Args:
x:
- Its shape is [N, T, C]
+ Its shape is (N, T, C)
Returns:
- Return a tensor of shape [N, T, C]
+ Return a tensor of shape (N, T, C)
"""
self.extend_pe(x)
x = x * self.xscale + self.pe[:, : x.size(1), :]
@@ -775,7 +782,8 @@ class Noam(object):
class LabelSmoothingLoss(nn.Module):
"""
- Label-smoothing loss. KL-divergence between q_{smoothed ground truth prob.}(w)
+ Label-smoothing loss. KL-divergence between
+ q_{smoothed ground truth prob.}(w)
and p_{prob. computed by model}(w) is minimized.
Modified from
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py # noqa
@@ -860,7 +868,8 @@ def encoder_padding_mask(
frames, before subsampling)
Returns:
- Tensor: Mask tensor of dimension (batch_size, input_length), True denote the masked indices.
+ Tensor: Mask tensor of dimension (batch_size, input_length),
+ True denote the masked indices.
"""
if supervisions is None:
return None
diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py
index 9f28bb74d..098d5d6a3 100755
--- a/egs/librispeech/ASR/local/compile_hlg.py
+++ b/egs/librispeech/ASR/local/compile_hlg.py
@@ -1,4 +1,20 @@
#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""
This script takes as input lang_dir and generates HLG from
@@ -86,14 +102,14 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
LG.labels[LG.labels >= first_token_disambig_id] = 0
- assert isinstance(LG.aux_labels, k2.RaggedInt)
- LG.aux_labels.values()[LG.aux_labels.values() >= first_word_disambig_id] = 0
+ assert isinstance(LG.aux_labels, k2.RaggedTensor)
+ LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
LG = k2.remove_epsilon(LG)
logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
LG = k2.connect(LG)
- LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0)
+ LG.aux_labels = LG.aux_labels.remove_values_eq(0)
logging.info("Arc sorting LG")
LG = k2.arc_sort(LG)
diff --git a/egs/librispeech/ASR/local/compute_fbank_librispeech.py b/egs/librispeech/ASR/local/compute_fbank_librispeech.py
index d81096070..b26034eb2 100755
--- a/egs/librispeech/ASR/local/compute_fbank_librispeech.py
+++ b/egs/librispeech/ASR/local/compute_fbank_librispeech.py
@@ -1,8 +1,24 @@
#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""
This file computes fbank features of the LibriSpeech dataset.
-Its looks for manifests in the directory data/manifests.
+It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
@@ -17,8 +33,9 @@ from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
-# Torch's multithreaded behavior needs to be disabled or it wastes a lot of CPU and
-# slow things down. Do this outside of main() in case it needs to take effect
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slow things down.
+# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
@@ -53,7 +70,8 @@ def compute_fbank_librispeech():
continue
logging.info(f"Processing {partition}")
cut_set = CutSet.from_manifests(
- recordings=m["recordings"], supervisions=m["supervisions"],
+ recordings=m["recordings"],
+ supervisions=m["supervisions"],
)
if "train" in partition:
cut_set = (
diff --git a/egs/librispeech/ASR/local/compute_fbank_musan.py b/egs/librispeech/ASR/local/compute_fbank_musan.py
index 0fc515d8c..d44524e70 100755
--- a/egs/librispeech/ASR/local/compute_fbank_musan.py
+++ b/egs/librispeech/ASR/local/compute_fbank_musan.py
@@ -1,8 +1,24 @@
#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""
This file computes fbank features of the musan dataset.
-Its looks for manifests in the directory data/manifests.
+It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
@@ -17,8 +33,9 @@ from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
-# Torch's multithreaded behavior needs to be disabled or it wastes a lot of CPU and
-# slow things down. Do this outside of main() in case it needs to take effect
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slow things down.
+# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
diff --git a/egs/librispeech/ASR/local/download_lm.py b/egs/librispeech/ASR/local/download_lm.py
index 5c9e2a675..94d23afed 100755
--- a/egs/librispeech/ASR/local/download_lm.py
+++ b/egs/librispeech/ASR/local/download_lm.py
@@ -1,6 +1,21 @@
#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
-# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
"""
This file downloads the following LibriSpeech LM files:
diff --git a/egs/librispeech/ASR/local/prepare_lang.py b/egs/librispeech/ASR/local/prepare_lang.py
index f7fde7796..0880019b3 100755
--- a/egs/librispeech/ASR/local/prepare_lang.py
+++ b/egs/librispeech/ASR/local/prepare_lang.py
@@ -1,6 +1,20 @@
#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
-# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
"""
This script takes as input a lexicon file "data/lang_phone/lexicon.txt"
diff --git a/egs/librispeech/ASR/local/prepare_lang_bpe.py b/egs/librispeech/ASR/local/prepare_lang_bpe.py
index 68b8db966..39d347661 100755
--- a/egs/librispeech/ASR/local/prepare_lang_bpe.py
+++ b/egs/librispeech/ASR/local/prepare_lang_bpe.py
@@ -1,4 +1,20 @@
#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
diff --git a/egs/librispeech/ASR/local/test_prepare_lang.py b/egs/librispeech/ASR/local/test_prepare_lang.py
index 23ab53c7d..d4cf62bba 100755
--- a/egs/librispeech/ASR/local/test_prepare_lang.py
+++ b/egs/librispeech/ASR/local/test_prepare_lang.py
@@ -1,4 +1,20 @@
#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
diff --git a/egs/librispeech/ASR/local/train_bpe_model.py b/egs/librispeech/ASR/local/train_bpe_model.py
index 9872a7c6a..3c3ecdcae 100755
--- a/egs/librispeech/ASR/local/train_bpe_model.py
+++ b/egs/librispeech/ASR/local/train_bpe_model.py
@@ -1,4 +1,19 @@
#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
# You can install sentencepiece via:
#
diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh
index 6479973bf..564f0d067 100755
--- a/egs/librispeech/ASR/prepare.sh
+++ b/egs/librispeech/ASR/prepare.sh
@@ -25,7 +25,7 @@ stop_stage=100
# - librispeech-vocab.txt
# - librispeech-lexicon.txt
#
-# - $do_dir/musan
+# - $dl_dir/musan
# This directory contains the following directories downloaded from
# http://www.openslr.org/17/
#
@@ -37,12 +37,13 @@ dl_dir=$PWD/download
. shared/parse_options.sh || exit 1
# vocab size for sentence piece models.
-# It will generate data/lang_bpe_500, data/lang_bpe_1000,
-# and data/lang_bpe_5000.
+# It will generate data/lang_bpe_xxx,
+# data/lang_bpe_yyy if the array contains xxx, yyy
vocab_sizes=(
- 500
- 1000
5000
+ 2000
+ 1000
+ 500
)
# All files generated by this script are saved in "data".
@@ -59,6 +60,7 @@ log "dl_dir: $dl_dir"
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
log "stage -1: Download LM"
+ [ ! -e $dl_dir/lm ] && mkdir -p $dl_dir/lm
./local/download_lm.py --out-dir=$dl_dir/lm
fi
@@ -139,9 +141,9 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
if [ ! -f $lang_dir/train.txt ]; then
log "Generate data for BPE training"
files=$(
- find "data/LibriSpeech/train-clean-100" -name "*.trans.txt"
- find "data/LibriSpeech/train-clean-360" -name "*.trans.txt"
- find "data/LibriSpeech/train-other-500" -name "*.trans.txt"
+ find "$dl_dir/LibriSpeech/train-clean-100" -name "*.trans.txt"
+ find "$dl_dir/LibriSpeech/train-clean-360" -name "*.trans.txt"
+ find "$dl_dir/LibriSpeech/train-other-500" -name "*.trans.txt"
)
for f in ${files[@]}; do
cat $f | cut -d " " -f 2-
@@ -223,3 +225,5 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
./local/compile_hlg.py --lang-dir $lang_dir
done
fi
+
+cd data && ln -sfv lang_bpe_5000 lang_bpe
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/README.md b/egs/librispeech/ASR/tdnn_lstm_ctc/README.md
index 401f3e319..94d4ed6a3 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/README.md
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/README.md
@@ -1,22 +1,4 @@
-## (To be filled in)
-It will contain:
-
-- How to run
-- WERs
-
-```bash
-cd $PWD/..
-
-./prepare.sh
-
-./tdnn_lstm_ctc/train.py
-```
-
-If you have 4 GPUs and want to use GPU 1 and GPU 3 for DDP training,
-you can do the following:
-
-```
-export CUDA_VISIBLE_DEVICES="1,3"
-./tdnn_lstm_ctc/train.py --world-size=2
-```
+Please visit
+
+for how to run this recipe.
diff --git a/icefall/dataset/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
similarity index 65%
rename from icefall/dataset/asr_datamodule.py
rename to egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
index aae7af9ce..8290e71d1 100644
--- a/icefall/dataset/asr_datamodule.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -1,14 +1,33 @@
+# Copyright 2021 Piotr Żelasko
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
import argparse
import logging
+from functools import lru_cache
from pathlib import Path
from typing import List, Union
-from lhotse import Fbank, FbankConfig, load_manifest
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest
from lhotse.dataset import (
BucketingSampler,
CutConcatenate,
CutMix,
K2SpeechRecognitionDataset,
+ PrecomputedFeatures,
SingleCutSampler,
SpecAugment,
)
@@ -19,9 +38,9 @@ from icefall.dataset.datamodule import DataModule
from icefall.utils import str2bool
-class AsrDataModule(DataModule):
+class LibriSpeechAsrDataModule(DataModule):
"""
- DataModule for K2 ASR experiments.
+ DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
@@ -47,6 +66,13 @@ class AsrDataModule(DataModule):
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
+ group.add_argument(
+ "--full-libri",
+ type=str2bool,
+ default=True,
+ help="When enabled, use 960h LibriSpeech. "
+ "Otherwise, use 100h subset.",
+ )
group.add_argument(
"--feature-dir",
type=Path,
@@ -56,14 +82,14 @@ class AsrDataModule(DataModule):
group.add_argument(
"--max-duration",
type=int,
- default=500.0,
+ default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
- default=False,
+ default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
@@ -77,7 +103,7 @@ class AsrDataModule(DataModule):
group.add_argument(
"--concatenate-cuts",
type=str2bool,
- default=True,
+ default=False,
help="When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding.",
)
@@ -104,6 +130,29 @@ class AsrDataModule(DataModule):
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
+ group.add_argument(
+ "--shuffle",
+ type=str2bool,
+ default=True,
+ help="When enabled (=default), the examples will be "
+ "shuffled for each epoch.",
+ )
+ group.add_argument(
+ "--return-cuts",
+ type=str2bool,
+ default=True,
+ help="When enabled, each batch will have the "
+ "field: batch['supervisions']['cut'] with the cuts that "
+ "were used to construct it.",
+ )
+
+ group.add_argument(
+ "--num-workers",
+ type=int,
+ default=2,
+ help="The number of training dataloader workers that "
+ "collect the batches.",
+ )
def train_dataloaders(self) -> DataLoader:
logging.info("About to get train cuts")
@@ -138,9 +187,9 @@ class AsrDataModule(DataModule):
]
train = K2SpeechRecognitionDataset(
- cuts_train,
cut_transforms=transforms,
input_transforms=input_transforms,
+ return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
@@ -154,14 +203,13 @@ class AsrDataModule(DataModule):
# to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side.
- cuts_train = cuts_train.drop_features()
train = K2SpeechRecognitionDataset(
- cuts=cuts_train,
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
input_transforms=input_transforms,
+ return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
@@ -169,44 +217,60 @@ class AsrDataModule(DataModule):
train_sampler = BucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
- shuffle=True,
+ shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ bucket_method="equal_duration",
+ drop_last=True,
)
else:
logging.info("Using SingleCutSampler.")
train_sampler = SingleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
- shuffle=True,
+ shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
+
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
- num_workers=4,
- persistent_workers=True,
+ num_workers=self.args.num_workers,
+ persistent_workers=False,
)
+
return train_dl
def valid_dataloaders(self) -> DataLoader:
logging.info("About to get dev cuts")
cuts_valid = self.valid_cuts()
+ transforms = []
+ if self.args.concatenate_cuts:
+ transforms = [
+ CutConcatenate(
+ duration_factor=self.args.duration_factor, gap=self.args.gap
+ )
+ ] + transforms
+
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
- cuts_valid = cuts_valid.drop_features()
validate = K2SpeechRecognitionDataset(
- cuts_valid.drop_features(),
+ cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
+ return_cuts=self.args.return_cuts,
)
else:
- validate = K2SpeechRecognitionDataset(cuts_valid)
+ validate = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ return_cuts=self.args.return_cuts,
+ )
valid_sampler = SingleCutSampler(
cuts_valid,
max_duration=self.args.max_duration,
+ shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
@@ -214,8 +278,9 @@ class AsrDataModule(DataModule):
sampler=valid_sampler,
batch_size=None,
num_workers=2,
- persistent_workers=True,
+ persistent_workers=False,
)
+
return valid_dl
def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]:
@@ -228,10 +293,12 @@ class AsrDataModule(DataModule):
for cuts_test in cuts:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
- cuts_test,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
- ),
+ )
+ if self.args.on_the_fly_feats
+ else PrecomputedFeatures(),
+ return_cuts=self.args.return_cuts,
)
sampler = SingleCutSampler(
cuts_test, max_duration=self.args.max_duration
@@ -246,3 +313,42 @@ class AsrDataModule(DataModule):
return test_loaders
else:
return test_loaders[0]
+
+ @lru_cache()
+ def train_cuts(self) -> CutSet:
+ logging.info("About to get train cuts")
+ cuts_train = load_manifest(
+ self.args.feature_dir / "cuts_train-clean-100.json.gz"
+ )
+ if self.args.full_libri:
+ cuts_train = (
+ cuts_train
+ + load_manifest(
+ self.args.feature_dir / "cuts_train-clean-360.json.gz"
+ )
+ + load_manifest(
+ self.args.feature_dir / "cuts_train-other-500.json.gz"
+ )
+ )
+ return cuts_train
+
+ @lru_cache()
+ def valid_cuts(self) -> CutSet:
+ logging.info("About to get dev cuts")
+ cuts_valid = load_manifest(
+ self.args.feature_dir / "cuts_dev-clean.json.gz"
+ ) + load_manifest(self.args.feature_dir / "cuts_dev-other.json.gz")
+ return cuts_valid
+
+ @lru_cache()
+ def test_cuts(self) -> List[CutSet]:
+ test_sets = ["test-clean", "test-other"]
+ cuts = []
+ for test_set in test_sets:
+ logging.debug("About to get test cuts")
+ cuts.append(
+ load_manifest(
+ self.args.feature_dir / f"cuts_{test_set}.json.gz"
+ )
+ )
+ return cuts
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
index 137fa795c..1e91b1008 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
@@ -1,4 +1,19 @@
#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
@@ -10,10 +25,10 @@ from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
from model import TdnnLstm
from icefall.checkpoint import average_checkpoints, load_checkpoint
-from icefall.dataset.librispeech import LibriSpeechAsrDataModule
from icefall.decode import (
get_lattice,
nbest_decoding,
@@ -27,6 +42,7 @@ from icefall.utils import (
get_texts,
setup_logger,
store_transcripts,
+ str2bool,
write_error_stats,
)
@@ -39,7 +55,7 @@ def get_parser():
parser.add_argument(
"--epoch",
type=int,
- default=9,
+ default=19,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
@@ -51,6 +67,57 @@ def get_parser():
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
+ parser.add_argument(
+ "--method",
+ type=str,
+ default="whole-lattice-rescoring",
+ help="""Decoding method.
+ Supported values are:
+ - (1) 1best. Extract the best path from the decoding lattice as the
+ decoding result.
+ - (2) nbest. Extract n paths from the decoding lattice; the path
+ with the highest score is the decoding result.
+ - (3) nbest-rescoring. Extract n paths from the decoding lattice,
+ rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
+ the highest score is the decoding result.
+ - (4) whole-lattice-rescoring. Rescore the decoding lattice with an
+ n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
+ is the decoding result.
+ """,
+ )
+
+ parser.add_argument(
+ "--num-paths",
+ type=int,
+ default=100,
+ help="""Number of paths for n-best based decoding method.
+ Used only when "method" is one of the following values:
+ nbest, nbest-rescoring
+ """,
+ )
+
+ parser.add_argument(
+ "--lattice-score-scale",
+ type=float,
+ default=0.5,
+ help="""The scale to be applied to `lattice.scores`.
+ It's needed if you use any kinds of n-best based rescoring.
+ Used only when "method" is one of the following values:
+ nbest, nbest-rescoring
+ A smaller value results in more unique paths.
+ """,
+ )
+
+ parser.add_argument(
+ "--export",
+ type=str2bool,
+ default=False,
+ help="""When enabled, the averaged model is saved to
+ tdnn/exp/pretrained.pt. Note: only model.state_dict() is saved.
+ pretrained.pt contains a dict {"model": model.state_dict()},
+ which can be loaded by `icefall.checkpoint.load_checkpoint()`.
+ """,
+ )
return parser
@@ -67,14 +134,6 @@ def get_params() -> AttributeDict:
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
- # Possible values for method:
- # - 1best
- # - nbest
- # - nbest-rescoring
- # - whole-lattice-rescoring
- "method": "1best",
- # num_paths is used when method is "nbest" and "nbest-rescoring"
- "num_paths": 30,
}
)
return params
@@ -131,12 +190,12 @@ def decode_one_batch(
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
- # at entry, feature is [N, T, C]
+ # at entry, feature is (N, T, C)
- feature = feature.permute(0, 2, 1) # now feature is [N, C, T]
+ feature = feature.permute(0, 2, 1) # now feature is (N, C, T)
nnet_output = model(feature)
- # nnet_output is [N, T, C]
+ # nnet_output is (N, T, C)
supervisions = batch["supervisions"]
@@ -170,6 +229,7 @@ def decode_one_batch(
lattice=lattice,
num_paths=params.num_paths,
use_double_scores=params.use_double_scores,
+ lattice_score_scale=params.lattice_score_scale,
)
key = f"no_rescore-{params.num_paths}"
hyps = get_texts(best_path)
@@ -178,7 +238,8 @@ def decode_one_batch(
assert params.method in ["nbest-rescoring", "whole-lattice-rescoring"]
- lm_scale_list = [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
+ lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
+ lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
if params.method == "nbest-rescoring":
@@ -187,10 +248,13 @@ def decode_one_batch(
G=G,
num_paths=params.num_paths,
lm_scale_list=lm_scale_list,
+ lattice_score_scale=params.lattice_score_scale,
)
else:
best_path_dict = rescore_with_whole_lattice(
- lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list
+ lattice=lattice,
+ G_with_epsilon_loops=G,
+ lm_scale_list=lm_scale_list,
)
ans = dict()
@@ -236,7 +300,11 @@ def decode_dataset(
results = []
num_cuts = 0
- tot_num_cuts = len(dl.dataset.cuts)
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
@@ -263,10 +331,10 @@ def decode_dataset(
num_cuts += len(batch["supervisions"]["text"])
if batch_idx % 100 == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
logging.info(
- f"batch {batch_idx}, cuts processed until now is "
- f"{num_cuts}/{tot_num_cuts} "
- f"({float(num_cuts)/tot_num_cuts*100:.6f}%)"
+ f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
@@ -328,7 +396,9 @@ def main():
logging.info(f"device: {device}")
- HLG = k2.Fsa.from_dict(torch.load("data/lang_phone/HLG.pt"))
+ HLG = k2.Fsa.from_dict(
+ torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
+ )
HLG = HLG.to(device)
assert HLG.requires_grad is False
@@ -355,7 +425,7 @@ def main():
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
else:
logging.info("Loading pre-compiled G_4_gram.pt")
- d = torch.load(params.lm_dir / "G_4_gram.pt")
+ d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
G = k2.Fsa.from_dict(d).to(device)
if params.method == "whole-lattice-rescoring":
@@ -387,6 +457,13 @@ def main():
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames))
+ if params.export:
+ logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
+ torch.save(
+ {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
+ )
+ return
+
model.to(device)
model.eval()
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/model.py b/egs/librispeech/ASR/tdnn_lstm_ctc/model.py
index 0dc4228dc..5e04c11b4 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/model.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/model.py
@@ -1,3 +1,20 @@
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
import torch
import torch.nn as nn
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
new file mode 100755
index 000000000..0a543d859
--- /dev/null
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
@@ -0,0 +1,277 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+import math
+from typing import List
+
+import k2
+import kaldifeat
+import torch
+import torchaudio
+from model import TdnnLstm
+from torch.nn.utils.rnn import pad_sequence
+
+from icefall.decode import (
+ get_lattice,
+ one_best_decoding,
+ rescore_with_whole_lattice,
+)
+from icefall.utils import AttributeDict, get_texts
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--checkpoint",
+ type=str,
+ required=True,
+ help="Path to the checkpoint. "
+ "The checkpoint is assumed to be saved by "
+ "icefall.checkpoint.save_checkpoint().",
+ )
+
+ parser.add_argument(
+ "--words-file",
+ type=str,
+ required=True,
+ help="Path to words.txt",
+ )
+
+ parser.add_argument(
+ "--HLG", type=str, required=True, help="Path to HLG.pt."
+ )
+
+ parser.add_argument(
+ "--method",
+ type=str,
+ default="1best",
+ help="""Decoding method.
+ Possible values are:
+ (1) 1best - Use the best path as decoding output. Only
+ the transformer encoder output is used for decoding.
+ We call it HLG decoding.
+ (2) whole-lattice-rescoring - Use an LM to rescore the
+ decoding lattice and then use 1best to decode the
+ rescored lattice.
+ We call it HLG decoding + n-gram LM rescoring.
+ """,
+ )
+
+ parser.add_argument(
+ "--G",
+ type=str,
+ help="""An LM for rescoring.
+ Used only when method is
+ whole-lattice-rescoring.
+ It's usually a 4-gram LM.
+ """,
+ )
+
+ parser.add_argument(
+ "--ngram-lm-scale",
+ type=float,
+ default=0.8,
+ help="""
+ Used only when method is whole-lattice-rescoring.
+ It specifies the scale for n-gram LM scores.
+ (Note: You need to tune it on a dataset.)
+ """,
+ )
+
+ parser.add_argument(
+ "sound_files",
+ type=str,
+ nargs="+",
+ help="The input sound file(s) to transcribe. "
+ "Supported formats are those supported by torchaudio.load(). "
+ "For example, wav and flac are supported. "
+ "The sample rate has to be 16kHz.",
+ )
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ params = AttributeDict(
+ {
+ "feature_dim": 80,
+ "subsampling_factor": 3,
+ "num_classes": 72,
+ "sample_rate": 16000,
+ "search_beam": 20,
+ "output_beam": 5,
+ "min_active_states": 30,
+ "max_active_states": 10000,
+ "use_double_scores": True,
+ }
+ )
+ return params
+
+
+def read_sound_files(
+ filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+ """Read a list of sound files into a list 1-D float32 torch tensors.
+ Args:
+ filenames:
+ A list of sound filenames.
+ expected_sample_rate:
+ The expected sample rate of the sound files.
+ Returns:
+ Return a list of 1-D float32 torch tensors.
+ """
+ ans = []
+ for f in filenames:
+ wave, sample_rate = torchaudio.load(f)
+ assert sample_rate == expected_sample_rate, (
+ f"expected sample rate: {expected_sample_rate}. "
+ f"Given: {sample_rate}"
+ )
+ # We use only the first channel
+ ans.append(wave[0])
+ return ans
+
+
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ params = get_params()
+ params.update(vars(args))
+ logging.info(f"{params}")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ logging.info("Creating model")
+ model = TdnnLstm(
+ num_features=params.feature_dim,
+ num_classes=params.num_classes,
+ subsampling_factor=params.subsampling_factor,
+ )
+
+ checkpoint = torch.load(args.checkpoint, map_location="cpu")
+ model.load_state_dict(checkpoint["model"])
+ model.to(device)
+ model.eval()
+
+ logging.info(f"Loading HLG from {params.HLG}")
+ HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+ HLG = HLG.to(device)
+ if not hasattr(HLG, "lm_scores"):
+ # For whole-lattice-rescoring and attention-decoder
+ HLG.lm_scores = HLG.scores.clone()
+
+ if params.method == "whole-lattice-rescoring":
+ logging.info(f"Loading G from {params.G}")
+ G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+ # Add epsilon self-loops to G as we will compose
+ # it with the whole lattice later
+ G = G.to(device)
+ G = k2.add_epsilon_self_loops(G)
+ G = k2.arc_sort(G)
+ G.lm_scores = G.scores.clone()
+
+ logging.info("Constructing Fbank computer")
+ opts = kaldifeat.FbankOptions()
+ opts.device = device
+ opts.frame_opts.dither = 0
+ opts.frame_opts.snip_edges = False
+ opts.frame_opts.samp_freq = params.sample_rate
+ opts.mel_opts.num_bins = params.feature_dim
+
+ fbank = kaldifeat.Fbank(opts)
+
+ logging.info(f"Reading sound files: {params.sound_files}")
+ waves = read_sound_files(
+ filenames=params.sound_files, expected_sample_rate=params.sample_rate
+ )
+ waves = [w.to(device) for w in waves]
+
+ logging.info("Decoding started")
+ features = fbank(waves)
+
+ features = pad_sequence(
+ features, batch_first=True, padding_value=math.log(1e-10)
+ )
+ features = features.permute(0, 2, 1) # now features is (N, C, T)
+
+ with torch.no_grad():
+ nnet_output = model(features)
+ # nnet_output is (N, T, C)
+
+ batch_size = nnet_output.shape[0]
+ supervision_segments = torch.tensor(
+ [[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
+ dtype=torch.int32,
+ )
+
+ lattice = get_lattice(
+ nnet_output=nnet_output,
+ HLG=HLG,
+ supervision_segments=supervision_segments,
+ search_beam=params.search_beam,
+ output_beam=params.output_beam,
+ min_active_states=params.min_active_states,
+ max_active_states=params.max_active_states,
+ subsampling_factor=params.subsampling_factor,
+ )
+
+ if params.method == "1best":
+ logging.info("Use HLG decoding")
+ best_path = one_best_decoding(
+ lattice=lattice, use_double_scores=params.use_double_scores
+ )
+ elif params.method == "whole-lattice-rescoring":
+ logging.info("Use HLG decoding + LM rescoring")
+ best_path_dict = rescore_with_whole_lattice(
+ lattice=lattice,
+ G_with_epsilon_loops=G,
+ lm_scale_list=[params.ngram_lm_scale],
+ )
+ best_path = next(iter(best_path_dict.values()))
+
+ hyps = get_texts(best_path)
+ word_sym_table = k2.SymbolTable.from_file(params.words_file)
+ hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
+
+ s = "\n"
+ for filename, hyp in zip(params.sound_files, hyps):
+ words = " ".join(hyp)
+ s += f"{filename}:\n{words}\n\n"
+ logging.info(s)
+
+ logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+ formatter = (
+ "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ )
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
index dbb9f64ec..695ee5130 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
@@ -1,6 +1,20 @@
#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
-# This is just at the very beginning ...
import argparse
import logging
@@ -14,16 +28,16 @@ import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
+from asr_datamodule import LibriSpeechAsrDataModule
from lhotse.utils import fix_random_seed
from model import TdnnLstm
from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.nn.utils import clip_grad_value_
+from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import StepLR
from torch.utils.tensorboard import SummaryWriter
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
-from icefall.dataset.librispeech import LibriSpeechAsrDataModule
from icefall.dist import cleanup_dist, setup_dist
from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon
@@ -61,9 +75,23 @@ def get_parser():
help="Should various information be logged in tensorboard.",
)
- # TODO: add extra arguments and support DDP training.
- # Currently, only single GPU training is implemented. Will add
- # DDP training once single GPU training is finished.
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=20,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=0,
+ help="""Resume training from from this epoch.
+ If it is positive, it will load checkpoint from
+ tdnn_lstm_ctc/exp/epoch-{start_epoch-1}.pt
+ """,
+ )
+
return parser
@@ -93,11 +121,6 @@ def get_params() -> AttributeDict:
- subsampling_factor: The subsampling factor for the model.
- - start_epoch: If it is not zero, load checkpoint `start_epoch-1`
- and continue training from that checkpoint.
-
- - num_epochs: Number of epochs to train.
-
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
@@ -116,6 +139,8 @@ def get_params() -> AttributeDict:
- log_interval: Print training loss if batch_idx % log_interval` is 0
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
- valid_interval: Run validation if batch_idx % valid_interval` is 0
- beam_size: It is used in k2.ctc_loss
@@ -132,14 +157,13 @@ def get_params() -> AttributeDict:
"feature_dim": 80,
"weight_decay": 5e-4,
"subsampling_factor": 3,
- "start_epoch": 0,
- "num_epochs": 10,
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 10,
+ "reset_interval": 200,
"valid_interval": 1000,
"beam_size": 10,
"reduction": "sum",
@@ -266,14 +290,14 @@ def compute_loss(
"""
device = graph_compiler.device
feature = batch["inputs"]
- # at entry, feature is [N, T, C]
- feature = feature.permute(0, 2, 1) # now feature is [N, C, T]
+ # at entry, feature is (N, T, C)
+ feature = feature.permute(0, 2, 1) # now feature is (N, C, T)
assert feature.ndim == 3
feature = feature.to(device)
with torch.set_grad_enabled(is_training):
nnet_output = model(feature)
- # nnet_output is [N, T, C]
+ # nnet_output is (N, T, C)
# NOTE: We need `encode_supervisions` to sort sequences with
# different duration in decreasing order, required by
@@ -387,8 +411,12 @@ def train_one_epoch(
"""
model.train()
- tot_loss = 0.0 # sum of losses over all batches
- tot_frames = 0.0 # sum of frames over all batches
+ tot_loss = 0.0 # reset after params.reset_interval of batches
+ tot_frames = 0.0 # reset after params.reset_interval of batches
+
+ params.tot_loss = 0.0
+ params.tot_frames = 0.0
+
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
@@ -406,7 +434,7 @@ def train_one_epoch(
optimizer.zero_grad()
loss.backward()
- clip_grad_value_(model.parameters(), 5.0)
+ clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
loss_cpu = loss.detach().cpu().item()
@@ -415,6 +443,9 @@ def train_one_epoch(
tot_loss += loss_cpu
tot_avg_loss = tot_loss / tot_frames
+ params.tot_frames += params.train_frames
+ params.tot_loss += loss_cpu
+
if batch_idx % params.log_interval == 0:
logging.info(
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
@@ -422,6 +453,22 @@ def train_one_epoch(
f"total avg loss: {tot_avg_loss:.4f}, "
f"batch size: {batch_size}"
)
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/current_loss",
+ loss_cpu / params.train_frames,
+ params.batch_idx_train,
+ )
+
+ tb_writer.add_scalar(
+ "train/tot_avg_loss",
+ tot_avg_loss,
+ params.batch_idx_train,
+ )
+
+ if batch_idx > 0 and batch_idx % params.reset_interval == 0:
+ tot_loss = 0
+ tot_frames = 0
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
compute_validation_loss(
@@ -438,7 +485,7 @@ def train_one_epoch(
f"best valid epoch: {params.best_valid_epoch}"
)
- params.train_loss = tot_loss / tot_frames
+ params.train_loss = params.tot_loss / params.tot_frames
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
diff --git a/egs/yesno/ASR/README.md b/egs/yesno/ASR/README.md
new file mode 100644
index 000000000..6f57412c0
--- /dev/null
+++ b/egs/yesno/ASR/README.md
@@ -0,0 +1,14 @@
+## Yesno recipe
+
+This is the simplest ASR recipe in `icefall`.
+
+It can be run on CPU and takes less than 30 seconds to
+get the following WER:
+
+```
+[test_set] %WER 0.42% [1 / 240, 0 ins, 1 del, 0 sub ]
+```
+
+Please refer to
+
+for detailed instructions.
diff --git a/egs/yesno/ASR/local/compile_hlg.py b/egs/yesno/ASR/local/compile_hlg.py
new file mode 100755
index 000000000..9b6a4c5ba
--- /dev/null
+++ b/egs/yesno/ASR/local/compile_hlg.py
@@ -0,0 +1,134 @@
+#!/usr/bin/env python3
+
+"""
+This script takes as input lang_dir and generates HLG from
+
+ - H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt
+ - L, the lexicon, built from lang_dir/L_disambig.pt
+
+ Caution: We use a lexicon that contains disambiguation symbols
+
+ - G, the LM, built from data/lm/G.fst.txt
+
+The generated HLG is saved in $lang_dir/HLG.pt
+"""
+import argparse
+import logging
+from pathlib import Path
+
+import k2
+import torch
+
+from icefall.lexicon import Lexicon
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ help="""Input and output directory.
+ """,
+ )
+
+ return parser.parse_args()
+
+
+def compile_HLG(lang_dir: str) -> k2.Fsa:
+ """
+ Args:
+ lang_dir:
+ The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
+
+ Return:
+ An FSA representing HLG.
+ """
+ lexicon = Lexicon(lang_dir)
+ max_token_id = max(lexicon.tokens)
+ logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
+ H = k2.ctc_topo(max_token_id)
+ L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
+
+ logging.info("Loading G.fst.txt")
+ with open("data/lm/G.fst.txt") as f:
+ G = k2.Fsa.from_openfst(f.read(), acceptor=False)
+
+ first_token_disambig_id = lexicon.token_table["#0"]
+ first_word_disambig_id = lexicon.word_table["#0"]
+
+ L = k2.arc_sort(L)
+ G = k2.arc_sort(G)
+
+ logging.info("Intersecting L and G")
+ LG = k2.compose(L, G)
+ logging.info(f"LG shape: {LG.shape}")
+
+ logging.info("Connecting LG")
+ LG = k2.connect(LG)
+ logging.info(f"LG shape after k2.connect: {LG.shape}")
+
+ logging.info(type(LG.aux_labels))
+ logging.info("Determinizing LG")
+
+ LG = k2.determinize(LG)
+ logging.info(type(LG.aux_labels))
+
+ logging.info("Connecting LG after k2.determinize")
+ LG = k2.connect(LG)
+
+ logging.info("Removing disambiguation symbols on LG")
+
+ LG.labels[LG.labels >= first_token_disambig_id] = 0
+
+ assert isinstance(LG.aux_labels, k2.RaggedTensor)
+ LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
+
+ LG = k2.remove_epsilon(LG)
+ logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
+
+ LG = k2.connect(LG)
+ LG.aux_labels = LG.aux_labels.remove_values_eq(0)
+
+ logging.info("Arc sorting LG")
+ LG = k2.arc_sort(LG)
+
+ logging.info("Composing H and LG")
+ # CAUTION: The name of the inner_labels is fixed
+ # to `tokens`. If you want to change it, please
+ # also change other places in icefall that are using
+ # it.
+ HLG = k2.compose(H, LG, inner_labels="tokens")
+
+ logging.info("Connecting LG")
+ HLG = k2.connect(HLG)
+
+ logging.info("Arc sorting LG")
+ HLG = k2.arc_sort(HLG)
+ logging.info(f"HLG.shape: {HLG.shape}")
+
+ return HLG
+
+
+def main():
+ args = get_args()
+ lang_dir = Path(args.lang_dir)
+
+ if (lang_dir / "HLG.pt").is_file():
+ logging.info(f"{lang_dir}/HLG.pt already exists - skipping")
+ return
+
+ logging.info(f"Processing {lang_dir}")
+
+ HLG = compile_HLG(lang_dir)
+ logging.info(f"Saving HLG.pt to {lang_dir}")
+ torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")
+
+
+if __name__ == "__main__":
+ formatter = (
+ "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ )
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ main()
diff --git a/egs/yesno/ASR/local/compute_fbank_yesno.py b/egs/yesno/ASR/local/compute_fbank_yesno.py
new file mode 100755
index 000000000..dad7319fd
--- /dev/null
+++ b/egs/yesno/ASR/local/compute_fbank_yesno.py
@@ -0,0 +1,81 @@
+#!/usr/bin/env python3
+
+"""
+This file computes fbank features of the yesno dataset.
+It looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
+"""
+
+import logging
+import os
+from pathlib import Path
+
+import torch
+from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer
+from lhotse.recipes.utils import read_manifests_if_cached
+
+from icefall.utils import get_executor
+
+# Torch's multithreaded behavior needs to be disabled or it wastes a
+# lot of CPU and slow things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def compute_fbank_yesno():
+ src_dir = Path("data/manifests")
+ output_dir = Path("data/fbank")
+
+ # This dataset is rather small, so we use only one job
+ num_jobs = min(1, os.cpu_count())
+ num_mel_bins = 23
+
+ dataset_parts = (
+ "train",
+ "test",
+ )
+ manifests = read_manifests_if_cached(
+ dataset_parts=dataset_parts, output_dir=src_dir
+ )
+ assert manifests is not None
+
+ extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
+
+ with get_executor() as ex: # Initialize the executor only once.
+ for partition, m in manifests.items():
+ if (output_dir / f"cuts_{partition}.json.gz").is_file():
+ logging.info(f"{partition} already exists - skipping.")
+ continue
+ logging.info(f"Processing {partition}")
+ cut_set = CutSet.from_manifests(
+ recordings=m["recordings"],
+ supervisions=m["supervisions"],
+ )
+ if "train" in partition:
+ cut_set = (
+ cut_set
+ + cut_set.perturb_speed(0.9)
+ + cut_set.perturb_speed(1.1)
+ )
+ cut_set = cut_set.compute_and_store_features(
+ extractor=extractor,
+ storage_path=f"{output_dir}/feats_{partition}",
+ # when an executor is specified, make more partitions
+ num_jobs=num_jobs if ex is None else 1, # use one job
+ executor=ex,
+ storage_type=LilcomHdf5Writer,
+ )
+ cut_set.to_json(output_dir / f"cuts_{partition}.json.gz")
+
+
+if __name__ == "__main__":
+ formatter = (
+ "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ )
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ compute_fbank_yesno()
diff --git a/egs/yesno/ASR/local/prepare_lang.py b/egs/yesno/ASR/local/prepare_lang.py
new file mode 100755
index 000000000..f7fde7796
--- /dev/null
+++ b/egs/yesno/ASR/local/prepare_lang.py
@@ -0,0 +1,367 @@
+#!/usr/bin/env python3
+
+# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
+
+"""
+This script takes as input a lexicon file "data/lang_phone/lexicon.txt"
+consisting of words and tokens (i.e., phones) and does the following:
+
+1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt
+
+2. Generate tokens.txt, the token table mapping a token to a unique integer.
+
+3. Generate words.txt, the word table mapping a word to a unique integer.
+
+4. Generate L.pt, in k2 format. It can be loaded by
+
+ d = torch.load("L.pt")
+ lexicon = k2.Fsa.from_dict(d)
+
+5. Generate L_disambig.pt, in k2 format.
+"""
+import math
+from collections import defaultdict
+from pathlib import Path
+from typing import Any, Dict, List, Tuple
+
+import k2
+import torch
+
+from icefall.lexicon import read_lexicon, write_lexicon
+
+Lexicon = List[Tuple[str, List[str]]]
+
+
+def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
+ """Write a symbol to ID mapping to a file.
+
+ Note:
+ No need to implement `read_mapping` as it can be done
+ through :func:`k2.SymbolTable.from_file`.
+
+ Args:
+ filename:
+ Filename to save the mapping.
+ sym2id:
+ A dict mapping symbols to IDs.
+ Returns:
+ Return None.
+ """
+ with open(filename, "w", encoding="utf-8") as f:
+ for sym, i in sym2id.items():
+ f.write(f"{sym} {i}\n")
+
+
+def get_tokens(lexicon: Lexicon) -> List[str]:
+ """Get tokens from a lexicon.
+
+ Args:
+ lexicon:
+ It is the return value of :func:`read_lexicon`.
+ Returns:
+ Return a list of unique tokens.
+ """
+ ans = set()
+ for _, tokens in lexicon:
+ ans.update(tokens)
+ sorted_ans = sorted(list(ans))
+ return sorted_ans
+
+
+def get_words(lexicon: Lexicon) -> List[str]:
+ """Get words from a lexicon.
+
+ Args:
+ lexicon:
+ It is the return value of :func:`read_lexicon`.
+ Returns:
+ Return a list of unique words.
+ """
+ ans = set()
+ for word, _ in lexicon:
+ ans.add(word)
+ sorted_ans = sorted(list(ans))
+ return sorted_ans
+
+
+def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]:
+ """It adds pseudo-token disambiguation symbols #1, #2 and so on
+ at the ends of tokens to ensure that all pronunciations are different,
+ and that none is a prefix of another.
+
+ See also add_lex_disambig.pl from kaldi.
+
+ Args:
+ lexicon:
+ It is returned by :func:`read_lexicon`.
+ Returns:
+ Return a tuple with two elements:
+
+ - The output lexicon with disambiguation symbols
+ - The ID of the max disambiguation symbol that appears
+ in the lexicon
+ """
+
+ # (1) Work out the count of each token-sequence in the
+ # lexicon.
+ count = defaultdict(int)
+ for _, tokens in lexicon:
+ count[" ".join(tokens)] += 1
+
+ # (2) For each left sub-sequence of each token-sequence, note down
+ # that it exists (for identifying prefixes of longer strings).
+ issubseq = defaultdict(int)
+ for _, tokens in lexicon:
+ tokens = tokens.copy()
+ tokens.pop()
+ while tokens:
+ issubseq[" ".join(tokens)] = 1
+ tokens.pop()
+
+ # (3) For each entry in the lexicon:
+ # if the token sequence is unique and is not a
+ # prefix of another word, no disambig symbol.
+ # Else output #1, or #2, #3, ... if the same token-seq
+ # has already been assigned a disambig symbol.
+ ans = []
+
+ # We start with #1 since #0 has its own purpose
+ first_allowed_disambig = 1
+ max_disambig = first_allowed_disambig - 1
+ last_used_disambig_symbol_of = defaultdict(int)
+
+ for word, tokens in lexicon:
+ tokenseq = " ".join(tokens)
+ assert tokenseq != ""
+ if issubseq[tokenseq] == 0 and count[tokenseq] == 1:
+ ans.append((word, tokens))
+ continue
+
+ cur_disambig = last_used_disambig_symbol_of[tokenseq]
+ if cur_disambig == 0:
+ cur_disambig = first_allowed_disambig
+ else:
+ cur_disambig += 1
+
+ if cur_disambig > max_disambig:
+ max_disambig = cur_disambig
+ last_used_disambig_symbol_of[tokenseq] = cur_disambig
+ tokenseq += f" #{cur_disambig}"
+ ans.append((word, tokenseq.split()))
+ return ans, max_disambig
+
+
+def generate_id_map(symbols: List[str]) -> Dict[str, int]:
+ """Generate ID maps, i.e., map a symbol to a unique ID.
+
+ Args:
+ symbols:
+ A list of unique symbols.
+ Returns:
+ A dict containing the mapping between symbols and IDs.
+ """
+ return {sym: i for i, sym in enumerate(symbols)}
+
+
+def add_self_loops(
+ arcs: List[List[Any]], disambig_token: int, disambig_word: int
+) -> List[List[Any]]:
+ """Adds self-loops to states of an FST to propagate disambiguation symbols
+ through it. They are added on each state with non-epsilon output symbols
+ on at least one arc out of the state.
+
+ See also fstaddselfloops.pl from Kaldi. One difference is that
+ Kaldi uses OpenFst style FSTs and it has multiple final states.
+ This function uses k2 style FSTs and it does not need to add self-loops
+ to the final state.
+
+ The input label of a self-loop is `disambig_token`, while the output
+ label is `disambig_word`.
+
+ Args:
+ arcs:
+ A list-of-list. The sublist contains
+ `[src_state, dest_state, label, aux_label, score]`
+ disambig_token:
+ It is the token ID of the symbol `#0`.
+ disambig_word:
+ It is the word ID of the symbol `#0`.
+
+ Return:
+ Return new `arcs` containing self-loops.
+ """
+ states_needs_self_loops = set()
+ for arc in arcs:
+ src, dst, ilabel, olabel, score = arc
+ if olabel != 0:
+ states_needs_self_loops.add(src)
+
+ ans = []
+ for s in states_needs_self_loops:
+ ans.append([s, s, disambig_token, disambig_word, 0])
+
+ return arcs + ans
+
+
+def lexicon_to_fst(
+ lexicon: Lexicon,
+ token2id: Dict[str, int],
+ word2id: Dict[str, int],
+ sil_token: str = "SIL",
+ sil_prob: float = 0.5,
+ need_self_loops: bool = False,
+) -> k2.Fsa:
+ """Convert a lexicon to an FST (in k2 format) with optional silence at
+ the beginning and end of each word.
+
+ Args:
+ lexicon:
+ The input lexicon. See also :func:`read_lexicon`
+ token2id:
+ A dict mapping tokens to IDs.
+ word2id:
+ A dict mapping words to IDs.
+ sil_token:
+ The silence token.
+ sil_prob:
+ The probability for adding a silence at the beginning and end
+ of the word.
+ need_self_loops:
+ If True, add self-loop to states with non-epsilon output symbols
+ on at least one arc out of the state. The input label for this
+ self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
+ Returns:
+ Return an instance of `k2.Fsa` representing the given lexicon.
+ """
+ assert sil_prob > 0.0 and sil_prob < 1.0
+ # CAUTION: we use score, i.e, negative cost.
+ sil_score = math.log(sil_prob)
+ no_sil_score = math.log(1.0 - sil_prob)
+
+ start_state = 0
+ loop_state = 1 # words enter and leave from here
+ sil_state = 2 # words terminate here when followed by silence; this state
+ # has a silence transition to loop_state.
+ next_state = 3 # the next un-allocated state, will be incremented as we go.
+ arcs = []
+
+ assert token2id[""] == 0
+ assert word2id[""] == 0
+
+ eps = 0
+
+ sil_token = token2id[sil_token]
+
+ arcs.append([start_state, loop_state, eps, eps, no_sil_score])
+ arcs.append([start_state, sil_state, eps, eps, sil_score])
+ arcs.append([sil_state, loop_state, sil_token, eps, 0])
+
+ for word, tokens in lexicon:
+ assert len(tokens) > 0, f"{word} has no pronunciations"
+ cur_state = loop_state
+
+ word = word2id[word]
+ tokens = [token2id[i] for i in tokens]
+
+ for i in range(len(tokens) - 1):
+ w = word if i == 0 else eps
+ arcs.append([cur_state, next_state, tokens[i], w, 0])
+
+ cur_state = next_state
+ next_state += 1
+
+ # now for the last token of this word
+ # It has two out-going arcs, one to the loop state,
+ # the other one to the sil_state.
+ i = len(tokens) - 1
+ w = word if i == 0 else eps
+ arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score])
+ arcs.append([cur_state, sil_state, tokens[i], w, sil_score])
+
+ if need_self_loops:
+ disambig_token = token2id["#0"]
+ disambig_word = word2id["#0"]
+ arcs = add_self_loops(
+ arcs,
+ disambig_token=disambig_token,
+ disambig_word=disambig_word,
+ )
+
+ final_state = next_state
+ arcs.append([loop_state, final_state, -1, -1, 0])
+ arcs.append([final_state])
+
+ arcs = sorted(arcs, key=lambda arc: arc[0])
+ arcs = [[str(i) for i in arc] for arc in arcs]
+ arcs = [" ".join(arc) for arc in arcs]
+ arcs = "\n".join(arcs)
+
+ fsa = k2.Fsa.from_str(arcs, acceptor=False)
+ return fsa
+
+
+def main():
+ out_dir = Path("data/lang_phone")
+ lexicon_filename = out_dir / "lexicon.txt"
+ sil_token = "SIL"
+ sil_prob = 0.5
+
+ lexicon = read_lexicon(lexicon_filename)
+ tokens = get_tokens(lexicon)
+ words = get_words(lexicon)
+
+ lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
+
+ for i in range(max_disambig + 1):
+ disambig = f"#{i}"
+ assert disambig not in tokens
+ tokens.append(f"#{i}")
+
+ assert "" not in tokens
+ tokens = [""] + tokens
+
+ assert "" not in words
+ assert "#0" not in words
+ assert "" not in words
+ assert "" not in words
+
+ words = [""] + words + ["#0", "", ""]
+
+ token2id = generate_id_map(tokens)
+ word2id = generate_id_map(words)
+
+ write_mapping(out_dir / "tokens.txt", token2id)
+ write_mapping(out_dir / "words.txt", word2id)
+ write_lexicon(out_dir / "lexicon_disambig.txt", lexicon_disambig)
+
+ L = lexicon_to_fst(
+ lexicon,
+ token2id=token2id,
+ word2id=word2id,
+ sil_token=sil_token,
+ sil_prob=sil_prob,
+ )
+
+ L_disambig = lexicon_to_fst(
+ lexicon_disambig,
+ token2id=token2id,
+ word2id=word2id,
+ sil_token=sil_token,
+ sil_prob=sil_prob,
+ need_self_loops=True,
+ )
+ torch.save(L.as_dict(), out_dir / "L.pt")
+ torch.save(L_disambig.as_dict(), out_dir / "L_disambig.pt")
+
+ if False:
+ # Just for debugging, will remove it
+ L.labels_sym = k2.SymbolTable.from_file(out_dir / "tokens.txt")
+ L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt")
+ L_disambig.labels_sym = L.labels_sym
+ L_disambig.aux_labels_sym = L.aux_labels_sym
+ L.draw(out_dir / "L.png", title="L")
+ L_disambig.draw(out_dir / "L_disambig.png", title="L_disambig")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/yesno/ASR/prepare.sh b/egs/yesno/ASR/prepare.sh
new file mode 100755
index 000000000..9a0cc48bb
--- /dev/null
+++ b/egs/yesno/ASR/prepare.sh
@@ -0,0 +1,93 @@
+#!/usr/bin/env bash
+
+set -eou pipefail
+
+stage=-1
+stop_stage=100
+
+dl_dir=$PWD/download
+
+lang_dir=data/lang_phone
+lm_dir=data/lm
+
+. shared/parse_options.sh || exit 1
+
+mkdir -p $lang_dir
+mkdir -p $lm_dir
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: $dl_dir"
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+ log "stage 0: Download data"
+ mkdir -p $dl_dir
+
+ if [ ! -f $dl_dir/waves_yesno/.completed ]; then
+ lhotse download yesno $dl_dir
+ fi
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+ log "Stage 1: Prepare yesno manifest"
+ mkdir -p data/manifests
+ lhotse prepare yesno $dl_dir/waves_yesno data/manifests
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+ log "Stage 2: Compute fbank for yesno"
+ mkdir -p data/fbank
+ ./local/compute_fbank_yesno.py
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+ log "Stage 3: Prepare lang"
+ # NOTE: " SIL" is added for implementation convenience
+ # as the graph compiler code requires that there is a OOV word
+ # in the lexicon.
+ (
+ echo " SIL"
+ echo "YES Y"
+ echo "NO N"
+ echo " SIL"
+ ) > $lang_dir/lexicon.txt
+
+ ./local/prepare_lang.py
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+ log "Stage 4: Prepare G"
+ # We use a unigram G
+ cat < $lm_dir/G.arpa
+
+\data\\
+ngram 1=4
+
+\1-grams:
+-1 NO
+-1 YES
+-99
+-1
+
+\end\\
+
+EOF
+
+ if [ ! -f $lm_dir/G.fst.txt ]; then
+ python3 -m kaldilm \
+ --read-symbol-table="$lang_dir/words.txt" \
+ --disambig-symbol='#0' \
+ $lm_dir/G.arpa > $lm_dir/G.fst.txt
+ fi
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+ log "Stage 5: Compile HLG"
+ if [ ! -f $lang_dir/HLG.pt ]; then
+ ./local/compile_hlg.py --lang-dir $lang_dir
+ fi
+fi
diff --git a/egs/yesno/ASR/shared b/egs/yesno/ASR/shared
new file mode 120000
index 000000000..4c5e91438
--- /dev/null
+++ b/egs/yesno/ASR/shared
@@ -0,0 +1 @@
+../../../icefall/shared/
\ No newline at end of file
diff --git a/egs/yesno/ASR/tdnn/README.md b/egs/yesno/ASR/tdnn/README.md
new file mode 100644
index 000000000..2b6116f0a
--- /dev/null
+++ b/egs/yesno/ASR/tdnn/README.md
@@ -0,0 +1,8 @@
+
+## How to run this recipe
+
+You can find detailed instructions by visiting
+
+
+It describes how to run this recipe and how to use
+a pre-trained model with `./pretrained.py`.
diff --git a/egs/yesno/ASR/tdnn/asr_datamodule.py b/egs/yesno/ASR/tdnn/asr_datamodule.py
new file mode 100644
index 000000000..e6614e3ce
--- /dev/null
+++ b/egs/yesno/ASR/tdnn/asr_datamodule.py
@@ -0,0 +1,248 @@
+# Copyright 2021 Piotr Żelasko
+# 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import logging
+from functools import lru_cache
+from pathlib import Path
+from typing import List
+
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest
+from lhotse.dataset import (
+ BucketingSampler,
+ CutConcatenate,
+ K2SpeechRecognitionDataset,
+ PrecomputedFeatures,
+ SingleCutSampler,
+)
+from lhotse.dataset.input_strategies import OnTheFlyFeatures
+from torch.utils.data import DataLoader
+
+from icefall.dataset.datamodule import DataModule
+from icefall.utils import str2bool
+
+
+class YesNoAsrDataModule(DataModule):
+ """
+ DataModule for k2 ASR experiments.
+ It assumes there is always one train dataloader,
+ but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
+ and test-other).
+
+ It contains all the common data pipeline modules used in ASR
+ experiments, e.g.:
+ - dynamic batch size,
+ - bucketing samplers,
+ - cut concatenation,
+ - augmentation,
+ - on-the-fly feature extraction
+ """
+
+ @classmethod
+ def add_arguments(cls, parser: argparse.ArgumentParser):
+ super().add_arguments(parser)
+ group = parser.add_argument_group(
+ title="ASR data related options",
+ description="These options are used for the preparation of "
+ "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+ "effective batch sizes, sampling strategies, applied data "
+ "augmentations, etc.",
+ )
+ group.add_argument(
+ "--feature-dir",
+ type=Path,
+ default=Path("data/fbank"),
+ help="Path to directory with train/test cuts.",
+ )
+ group.add_argument(
+ "--max-duration",
+ type=int,
+ default=30.0,
+ help="Maximum pooled recordings duration (seconds) in a "
+ "single batch. You can reduce it if it causes CUDA OOM.",
+ )
+ group.add_argument(
+ "--bucketing-sampler",
+ type=str2bool,
+ default=False,
+ help="When enabled, the batches will come from buckets of "
+ "similar duration (saves padding frames).",
+ )
+ group.add_argument(
+ "--num-buckets",
+ type=int,
+ default=10,
+ help="The number of buckets for the BucketingSampler"
+ "(you might want to increase it for larger datasets).",
+ )
+ group.add_argument(
+ "--concatenate-cuts",
+ type=str2bool,
+ default=False,
+ help="When enabled, utterances (cuts) will be concatenated "
+ "to minimize the amount of padding.",
+ )
+ group.add_argument(
+ "--duration-factor",
+ type=float,
+ default=1.0,
+ help="Determines the maximum duration of a concatenated cut "
+ "relative to the duration of the longest cut in a batch.",
+ )
+ group.add_argument(
+ "--gap",
+ type=float,
+ default=1.0,
+ help="The amount of padding (in seconds) inserted between "
+ "concatenated cuts. This padding is filled with noise when "
+ "noise augmentation is used.",
+ )
+ group.add_argument(
+ "--on-the-fly-feats",
+ type=str2bool,
+ default=False,
+ help="When enabled, use on-the-fly cut mixing and feature "
+ "extraction. Will drop existing precomputed feature manifests "
+ "if available.",
+ )
+ group.add_argument(
+ "--shuffle",
+ type=str2bool,
+ default=True,
+ help="When enabled (=default), the examples will be "
+ "shuffled for each epoch.",
+ )
+ group.add_argument(
+ "--return-cuts",
+ type=str2bool,
+ default=True,
+ help="When enabled, each batch will have the "
+ "field: batch['supervisions']['cut'] with the cuts that "
+ "were used to construct it.",
+ )
+
+ group.add_argument(
+ "--num-workers",
+ type=int,
+ default=2,
+ help="The number of training dataloader workers that "
+ "collect the batches.",
+ )
+
+ def train_dataloaders(self) -> DataLoader:
+ logging.info("About to get train cuts")
+ cuts_train = self.train_cuts()
+
+ logging.info("About to create train dataset")
+ transforms = []
+ if self.args.concatenate_cuts:
+ logging.info(
+ f"Using cut concatenation with duration factor "
+ f"{self.args.duration_factor} and gap {self.args.gap}."
+ )
+ # Cut concatenation should be the first transform in the list,
+ # so that if we e.g. mix noise in, it will fill the gaps between
+ # different utterances.
+ transforms = [
+ CutConcatenate(
+ duration_factor=self.args.duration_factor, gap=self.args.gap
+ )
+ ] + transforms
+
+ train = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.on_the_fly_feats:
+ # NOTE: the PerturbSpeed transform should be added only if we
+ # remove it from data prep stage.
+ # Add on-the-fly speed perturbation; since originally it would
+ # have increased epoch size by 3, we will apply prob 2/3 and use
+ # 3x more epochs.
+ # Speed perturbation probably should come first before
+ # concatenation, but in principle the transforms order doesn't have
+ # to be strict (e.g. could be randomized)
+ # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
+ # Drop feats to be on the safe side.
+ train = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_strategy=OnTheFlyFeatures(
+ Fbank(FbankConfig(num_mel_bins=23))
+ ),
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.bucketing_sampler:
+ logging.info("Using BucketingSampler.")
+ train_sampler = BucketingSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ num_buckets=self.args.num_buckets,
+ bucket_method="equal_duration",
+ drop_last=True,
+ )
+ else:
+ logging.info("Using SingleCutSampler.")
+ train_sampler = SingleCutSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ )
+ logging.info("About to create train dataloader")
+
+ train_dl = DataLoader(
+ train,
+ sampler=train_sampler,
+ batch_size=None,
+ num_workers=self.args.num_workers,
+ persistent_workers=False,
+ )
+
+ return train_dl
+
+ def test_dataloaders(self) -> DataLoader:
+ logging.info("About to get test cuts")
+ cuts_test = self.test_cuts()
+
+ logging.debug("About to create test dataset")
+ test = K2SpeechRecognitionDataset(
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=23)))
+ if self.args.on_the_fly_feats
+ else PrecomputedFeatures(),
+ return_cuts=self.args.return_cuts,
+ )
+ sampler = SingleCutSampler(
+ cuts_test, max_duration=self.args.max_duration
+ )
+ logging.debug("About to create test dataloader")
+ test_dl = DataLoader(
+ test, batch_size=None, sampler=sampler, num_workers=1
+ )
+ return test_dl
+
+ @lru_cache()
+ def train_cuts(self) -> CutSet:
+ logging.info("About to get train cuts")
+ cuts_train = load_manifest(self.args.feature_dir / "cuts_train.json.gz")
+ return cuts_train
+
+ @lru_cache()
+ def test_cuts(self) -> List[CutSet]:
+ logging.info("About to get test cuts")
+ cuts_test = load_manifest(self.args.feature_dir / "cuts_test.json.gz")
+ return cuts_test
diff --git a/egs/yesno/ASR/tdnn/decode.py b/egs/yesno/ASR/tdnn/decode.py
new file mode 100755
index 000000000..325acf316
--- /dev/null
+++ b/egs/yesno/ASR/tdnn/decode.py
@@ -0,0 +1,322 @@
+#!/usr/bin/env python3
+
+
+import argparse
+import logging
+from pathlib import Path
+from typing import List, Tuple
+
+import k2
+import torch
+import torch.nn as nn
+from asr_datamodule import YesNoAsrDataModule
+from model import Tdnn
+
+from icefall.checkpoint import average_checkpoints, load_checkpoint
+from icefall.decode import get_lattice, one_best_decoding
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ get_texts,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=14,
+ help="It specifies the checkpoint to use for decoding."
+ "Note: Epoch counts from 0.",
+ )
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=2,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch'. ",
+ )
+
+ parser.add_argument(
+ "--export",
+ type=str2bool,
+ default=False,
+ help="""When enabled, the averaged model is saved to
+ tdnn/exp/pretrained.pt. Note: only model.state_dict() is saved.
+ pretrained.pt contains a dict {"model": model.state_dict()},
+ which can be loaded by `icefall.checkpoint.load_checkpoint()`.
+ """,
+ )
+ return parser
+
+
+def get_params() -> AttributeDict:
+ params = AttributeDict(
+ {
+ "exp_dir": Path("tdnn/exp/"),
+ "lang_dir": Path("data/lang_phone"),
+ "lm_dir": Path("data/lm"),
+ "feature_dim": 23,
+ "search_beam": 20,
+ "output_beam": 8,
+ "min_active_states": 30,
+ "max_active_states": 10000,
+ "use_double_scores": True,
+ }
+ )
+ return params
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ HLG: k2.Fsa,
+ batch: dict,
+ word_table: k2.SymbolTable,
+) -> List[List[int]]:
+ """Decode one batch and return the result in a list-of-list.
+ Each sub list contains the word IDs for an utterance in the batch.
+
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+
+ - params.method is "1best", it uses 1best decoding.
+ - params.method is "nbest", it uses nbest decoding.
+
+ model:
+ The neural model.
+ HLG:
+ The decoding graph.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ (https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py)
+ word_table:
+ It is the word symbol table.
+ Returns:
+ Return the decoding result. `len(ans)` == batch size.
+ """
+ device = HLG.device
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+ feature = feature.to(device)
+ # at entry, feature is (N, T, C)
+
+ nnet_output = model(feature)
+ # nnet_output is (N, T, C)
+
+ batch_size = nnet_output.shape[0]
+ supervision_segments = torch.tensor(
+ [[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
+ dtype=torch.int32,
+ )
+
+ lattice = get_lattice(
+ nnet_output=nnet_output,
+ HLG=HLG,
+ supervision_segments=supervision_segments,
+ search_beam=params.search_beam,
+ output_beam=params.output_beam,
+ min_active_states=params.min_active_states,
+ max_active_states=params.max_active_states,
+ )
+
+ best_path = one_best_decoding(
+ lattice=lattice, use_double_scores=params.use_double_scores
+ )
+ hyps = get_texts(best_path)
+ hyps = [[word_table[i] for i in ids] for ids in hyps]
+ return hyps
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ HLG: k2.Fsa,
+ word_table: k2.SymbolTable,
+) -> List[Tuple[List[int], List[int]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ HLG:
+ The decoding graph.
+ word_table:
+ It is word symbol table.
+ Returns:
+ Return a tuple contains two elements (ref_text, hyp_text):
+ The first is the reference transcript, and the second is the
+ predicted result.
+ """
+ results = []
+
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ results = []
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+
+ hyps = decode_one_batch(
+ params=params,
+ model=model,
+ HLG=HLG,
+ batch=batch,
+ word_table=word_table,
+ )
+
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for hyp_words, ref_text in zip(hyps, texts):
+ ref_words = ref_text.split()
+ this_batch.append((ref_words, hyp_words))
+
+ results.extend(this_batch)
+
+ num_cuts += len(batch["supervisions"]["text"])
+
+ if batch_idx % 100 == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(
+ f"batch {batch_str}, cuts processed until now is {num_cuts}"
+ )
+ return results
+
+
+def save_results(
+ exp_dir: Path,
+ test_set_name: str,
+ results: List[Tuple[List[int], List[int]]],
+) -> None:
+ """Save results to `exp_dir`.
+ Args:
+ exp_dir:
+ The output directory. This function create the following files inside
+ this directory:
+
+ - recogs-{test_set_name}.text
+
+ It contains the reference and hypothesis results, like below::
+
+ ref=['NO', 'NO', 'NO', 'YES', 'NO', 'NO', 'NO', 'YES']
+ hyp=['NO', 'NO', 'NO', 'YES', 'NO', 'NO', 'NO', 'YES']
+ ref=['NO', 'NO', 'YES', 'NO', 'YES', 'NO', 'NO', 'YES']
+ hyp=['NO', 'NO', 'YES', 'NO', 'YES', 'NO', 'NO', 'YES']
+
+ - errs-{test_set_name}.txt
+
+ It contains the detailed WER.
+ test_set_name:
+ The name of the test set, which will be part of the result filename.
+ results:
+ A list of tuples, each of which contains (ref_words, hyp_words).
+ Returns:
+ Return None.
+ """
+ recog_path = exp_dir / f"recogs-{test_set_name}.txt"
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = exp_dir / f"errs-{test_set_name}.txt"
+ with open(errs_filename, "w") as f:
+ write_error_stats(f, f"{test_set_name}", results)
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ YesNoAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+
+ params = get_params()
+ params.update(vars(args))
+
+ setup_logger(f"{params.exp_dir}/log/log-decode")
+ logging.info("Decoding started")
+ logging.info(params)
+
+ lexicon = Lexicon(params.lang_dir)
+ max_token_id = max(lexicon.tokens)
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ HLG = k2.Fsa.from_dict(
+ torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
+ )
+ HLG = HLG.to(device)
+ assert HLG.requires_grad is False
+
+ model = Tdnn(
+ num_features=params.feature_dim,
+ num_classes=max_token_id + 1, # +1 for the blank symbol
+ )
+ if params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if start >= 0:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.load_state_dict(average_checkpoints(filenames))
+
+ if params.export:
+ logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
+ torch.save(
+ {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
+ )
+ return
+
+ model.to(device)
+ model.eval()
+
+ yes_no = YesNoAsrDataModule(args)
+ test_dl = yes_no.test_dataloaders()
+ results = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ HLG=HLG,
+ word_table=lexicon.word_table,
+ )
+
+ save_results(
+ exp_dir=params.exp_dir, test_set_name="test_set", results=results
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/yesno/ASR/tdnn/model.py b/egs/yesno/ASR/tdnn/model.py
new file mode 100755
index 000000000..52cff37e0
--- /dev/null
+++ b/egs/yesno/ASR/tdnn/model.py
@@ -0,0 +1,81 @@
+#!/usr/bin/env python3
+
+# Copyright (c) 2021 Xiaomi Corp. (author: Fangjun Kuang)
+
+
+import torch
+import torch.nn as nn
+
+
+class Tdnn(nn.Module):
+ def __init__(self, num_features: int, num_classes: int):
+ """
+ Args:
+ num_features:
+ Model input dimension.
+ num_classes:
+ Model output dimension
+ """
+ super().__init__()
+
+ self.tdnn = nn.Sequential(
+ nn.Conv1d(
+ in_channels=num_features,
+ out_channels=32,
+ kernel_size=3,
+ ),
+ nn.ReLU(inplace=True),
+ nn.BatchNorm1d(num_features=32, affine=False),
+ nn.Conv1d(
+ in_channels=32,
+ out_channels=32,
+ kernel_size=5,
+ dilation=2,
+ ),
+ nn.ReLU(inplace=True),
+ nn.BatchNorm1d(num_features=32, affine=False),
+ nn.Conv1d(
+ in_channels=32,
+ out_channels=32,
+ kernel_size=5,
+ dilation=4,
+ ),
+ nn.ReLU(inplace=True),
+ nn.BatchNorm1d(num_features=32, affine=False),
+ )
+ self.output_linear = nn.Linear(in_features=32, out_features=num_classes)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x:
+ The input tensor with shape [N, T, C]
+
+ Returns:
+ The output tensor has shape [N, T, C]
+ """
+ x = x.permute(0, 2, 1) # [N, T, C] -> [N, C, T]
+ x = self.tdnn(x)
+ x = x.permute(0, 2, 1) # [N, C, T] -> [N, T, C]
+ x = self.output_linear(x)
+ x = nn.functional.log_softmax(x, dim=-1)
+ return x
+
+
+def test_tdnn():
+ num_features = 23
+ num_classes = 4
+ model = Tdnn(num_features=num_features, num_classes=num_classes)
+ num_param = sum([p.numel() for p in model.parameters()])
+ print(f"Number of model parameters: {num_param}")
+ N = 2
+ T = 100
+ C = num_features
+ x = torch.randn(N, T, C)
+ y = model(x)
+ print(x.shape)
+ print(y.shape)
+
+
+if __name__ == "__main__":
+ test_tdnn()
diff --git a/egs/yesno/ASR/tdnn/pretrained.py b/egs/yesno/ASR/tdnn/pretrained.py
new file mode 100755
index 000000000..fb92110e3
--- /dev/null
+++ b/egs/yesno/ASR/tdnn/pretrained.py
@@ -0,0 +1,209 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+import math
+from typing import List
+
+import k2
+import kaldifeat
+import torch
+import torchaudio
+from model import Tdnn
+from torch.nn.utils.rnn import pad_sequence
+
+from icefall.decode import get_lattice, one_best_decoding
+from icefall.utils import AttributeDict, get_texts
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--checkpoint",
+ type=str,
+ required=True,
+ help="Path to the checkpoint. "
+ "The checkpoint is assumed to be saved by "
+ "icefall.checkpoint.save_checkpoint().",
+ )
+
+ parser.add_argument(
+ "--words-file",
+ type=str,
+ required=True,
+ help="Path to words.txt",
+ )
+
+ parser.add_argument(
+ "--HLG", type=str, required=True, help="Path to HLG.pt."
+ )
+
+ parser.add_argument(
+ "sound_files",
+ type=str,
+ nargs="+",
+ help="The input sound file(s) to transcribe. "
+ "Supported formats are those supported by torchaudio.load(). "
+ "For example, wav and flac are supported. "
+ "The sample rate has to be 16kHz.",
+ )
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ params = AttributeDict(
+ {
+ "feature_dim": 23,
+ "num_classes": 4, # [, N, SIL, Y]
+ "sample_rate": 8000,
+ "search_beam": 20,
+ "output_beam": 8,
+ "min_active_states": 30,
+ "max_active_states": 10000,
+ "use_double_scores": True,
+ }
+ )
+ return params
+
+
+def read_sound_files(
+ filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+ """Read a list of sound files into a list 1-D float32 torch tensors.
+ Args:
+ filenames:
+ A list of sound filenames.
+ expected_sample_rate:
+ The expected sample rate of the sound files.
+ Returns:
+ Return a list of 1-D float32 torch tensors.
+ """
+ ans = []
+ for f in filenames:
+ wave, sample_rate = torchaudio.load(f)
+ assert sample_rate == expected_sample_rate, (
+ f"expected sample rate: {expected_sample_rate}. "
+ f"Given: {sample_rate}"
+ )
+ # We use only the first channel
+ ans.append(wave[0])
+ return ans
+
+
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ params = get_params()
+ params.update(vars(args))
+ logging.info(f"{params}")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ logging.info("Creating model")
+
+ model = Tdnn(
+ num_features=params.feature_dim,
+ num_classes=params.num_classes,
+ )
+
+ checkpoint = torch.load(args.checkpoint, map_location="cpu")
+ model.load_state_dict(checkpoint["model"])
+ model.to(device)
+ model.eval()
+
+ logging.info(f"Loading HLG from {params.HLG}")
+ HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+ HLG = HLG.to(device)
+
+ logging.info("Constructing Fbank computer")
+ opts = kaldifeat.FbankOptions()
+ opts.device = device
+ opts.frame_opts.dither = 0
+ opts.frame_opts.snip_edges = False
+ opts.frame_opts.samp_freq = params.sample_rate
+ opts.mel_opts.num_bins = params.feature_dim
+
+ fbank = kaldifeat.Fbank(opts)
+
+ logging.info(f"Reading sound files: {params.sound_files}")
+ waves = read_sound_files(
+ filenames=params.sound_files, expected_sample_rate=params.sample_rate
+ )
+ waves = [w.to(device) for w in waves]
+
+ logging.info("Decoding started")
+ features = fbank(waves)
+
+ features = pad_sequence(
+ features, batch_first=True, padding_value=math.log(1e-10)
+ )
+
+ # Note: We don't use key padding mask for attention during decoding
+ with torch.no_grad():
+ nnet_output = model(features)
+
+ batch_size = nnet_output.shape[0]
+ supervision_segments = torch.tensor(
+ [[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
+ dtype=torch.int32,
+ )
+
+ lattice = get_lattice(
+ nnet_output=nnet_output,
+ HLG=HLG,
+ supervision_segments=supervision_segments,
+ search_beam=params.search_beam,
+ output_beam=params.output_beam,
+ min_active_states=params.min_active_states,
+ max_active_states=params.max_active_states,
+ )
+
+ best_path = one_best_decoding(
+ lattice=lattice, use_double_scores=params.use_double_scores
+ )
+
+ hyps = get_texts(best_path)
+ word_sym_table = k2.SymbolTable.from_file(params.words_file)
+ hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
+
+ s = "\n"
+ for filename, hyp in zip(params.sound_files, hyps):
+ words = " ".join(hyp)
+ s += f"{filename}:\n{words}\n\n"
+ logging.info(s)
+
+ logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+ formatter = (
+ "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ )
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/yesno/ASR/tdnn/train.py b/egs/yesno/ASR/tdnn/train.py
new file mode 100755
index 000000000..0f5506d38
--- /dev/null
+++ b/egs/yesno/ASR/tdnn/train.py
@@ -0,0 +1,584 @@
+#!/usr/bin/env python3
+
+import argparse
+import logging
+from pathlib import Path
+from shutil import copyfile
+from typing import Optional
+
+import k2
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import torch.nn as nn
+import torch.optim as optim
+from asr_datamodule import YesNoAsrDataModule
+from lhotse.utils import fix_random_seed
+from model import Tdnn
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.nn.utils import clip_grad_norm_
+from torch.utils.tensorboard import SummaryWriter
+
+from icefall.checkpoint import load_checkpoint
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.graph_compiler import CtcTrainingGraphCompiler
+from icefall.lexicon import Lexicon
+from icefall.utils import AttributeDict, setup_logger, str2bool
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=15,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=0,
+ help="""Resume training from from this epoch.
+ If it is positive, it will load checkpoint from
+ tdnn/exp/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ is saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - exp_dir: It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+
+ - lang_dir: It contains language related input files such as
+ "lexicon.txt"
+
+ - lr: It specifies the initial learning rate
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - weight_decay: The weight_decay for the optimizer.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+ - start_epoch: If it is not zero, load checkpoint `start_epoch-1`
+ and continue training from that checkpoint.
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+ - batch_idx_train: Used to writing statistics to tensorboard. It
+ contains number of batches trained so far across
+ epochs.
+
+ - log_interval: Print training loss if batch_idx % log_interval` is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval` is 0
+
+ - beam_size: It is used in k2.ctc_loss
+
+ - reduction: It is used in k2.ctc_loss
+
+ - use_double_scores: It is used in k2.ctc_loss
+ """
+ params = AttributeDict(
+ {
+ "exp_dir": Path("tdnn/exp"),
+ "lang_dir": Path("data/lang_phone"),
+ "lr": 1e-2,
+ "feature_dim": 23,
+ "weight_decay": 1e-6,
+ "start_epoch": 0,
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 10,
+ "valid_interval": 10,
+ "beam_size": 10,
+ "reduction": "sum",
+ "use_double_scores": True,
+ }
+ )
+
+ return params
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+) -> None:
+ """Load checkpoint from file.
+
+ If params.start_epoch is positive, it will load the checkpoint from
+ `params.start_epoch - 1`. Otherwise, this function does nothing.
+
+ Apart from loading state dict for `model`, `optimizer` and `scheduler`,
+ it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The learning rate scheduler we are using.
+ Returns:
+ Return None.
+ """
+ if params.start_epoch <= 0:
+ return
+
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: nn.Module,
+ optimizer: torch.optim.Optimizer,
+ scheduler: torch.optim.lr_scheduler._LRScheduler,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: nn.Module,
+ batch: dict,
+ graph_compiler: CtcTrainingGraphCompiler,
+ is_training: bool,
+):
+ """
+ Compute CTC loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Tdnn in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ graph_compiler:
+ It is used to build a decoding graph from a ctc topo and training
+ transcript. The training transcript is contained in the given `batch`,
+ while the ctc topo is built when this compiler is instantiated.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ """
+ device = graph_compiler.device
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ with torch.set_grad_enabled(is_training):
+ nnet_output = model(feature)
+ # nnet_output is (N, T, C)
+
+ # NOTE: We need `encode_supervisions` to sort sequences with
+ # different duration in decreasing order, required by
+ # `k2.intersect_dense` called in `k2.ctc_loss`
+ supervisions = batch["supervisions"]
+ texts = supervisions["text"]
+
+ batch_size = nnet_output.shape[0]
+ supervision_segments = torch.tensor(
+ [[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
+ dtype=torch.int32,
+ )
+
+ decoding_graph = graph_compiler.compile(texts)
+
+ dense_fsa_vec = k2.DenseFsaVec(
+ nnet_output,
+ supervision_segments,
+ )
+
+ loss = k2.ctc_loss(
+ decoding_graph=decoding_graph,
+ dense_fsa_vec=dense_fsa_vec,
+ output_beam=params.beam_size,
+ reduction=params.reduction,
+ use_double_scores=params.use_double_scores,
+ )
+
+ assert loss.requires_grad == is_training
+
+ # train_frames and valid_frames are used for printing.
+ if is_training:
+ params.train_frames = supervision_segments[:, 2].sum().item()
+ else:
+ params.valid_frames = supervision_segments[:, 2].sum().item()
+
+ return loss
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: nn.Module,
+ graph_compiler: CtcTrainingGraphCompiler,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> None:
+ """Run the validation process. The validation loss
+ is saved in `params.valid_loss`.
+ """
+ model.eval()
+
+ tot_loss = 0.0
+ tot_frames = 0.0
+ for batch_idx, batch in enumerate(valid_dl):
+ loss = compute_loss(
+ params=params,
+ model=model,
+ batch=batch,
+ graph_compiler=graph_compiler,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+
+ loss_cpu = loss.detach().cpu().item()
+ tot_loss += loss_cpu
+ tot_frames += params.valid_frames
+
+ if world_size > 1:
+ s = torch.tensor([tot_loss, tot_frames], device=loss.device)
+ dist.all_reduce(s, op=dist.ReduceOp.SUM)
+ s = s.cpu().tolist()
+ tot_loss = s[0]
+ tot_frames = s[1]
+
+ params.valid_loss = tot_loss / tot_frames
+
+ if params.valid_loss < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = params.valid_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: nn.Module,
+ optimizer: torch.optim.Optimizer,
+ graph_compiler: CtcTrainingGraphCompiler,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ graph_compiler:
+ It is used to convert transcripts to FSAs.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ """
+ model.train()
+
+ tot_loss = 0.0 # sum of losses over all batches
+ tot_frames = 0.0 # sum of frames over all batches
+ for batch_idx, batch in enumerate(train_dl):
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ loss = compute_loss(
+ params=params,
+ model=model,
+ batch=batch,
+ graph_compiler=graph_compiler,
+ is_training=True,
+ )
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+
+ optimizer.zero_grad()
+ loss.backward()
+ clip_grad_norm_(model.parameters(), 5.0, 2.0)
+ optimizer.step()
+
+ loss_cpu = loss.detach().cpu().item()
+
+ tot_frames += params.train_frames
+ tot_loss += loss_cpu
+ tot_avg_loss = tot_loss / tot_frames
+
+ if batch_idx % params.log_interval == 0:
+ logging.info(
+ f"Epoch {params.cur_epoch}, batch {batch_idx}, "
+ f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
+ f"total avg loss: {tot_avg_loss:.4f}, "
+ f"batch size: {batch_size}"
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/current_loss",
+ loss_cpu / params.train_frames,
+ params.batch_idx_train,
+ )
+
+ tb_writer.add_scalar(
+ "train/tot_avg_loss",
+ tot_avg_loss,
+ params.batch_idx_train,
+ )
+
+ if batch_idx > 0 and batch_idx % params.valid_interval == 0:
+ compute_validation_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(
+ f"Epoch {params.cur_epoch}, valid loss {params.valid_loss:.4f},"
+ f" best valid loss: {params.best_valid_loss:.4f} "
+ f"best valid epoch: {params.best_valid_epoch}"
+ )
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/valid_loss",
+ params.valid_loss,
+ params.batch_idx_train,
+ )
+
+ params.train_loss = tot_loss / tot_frames
+
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(42)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+ logging.info(params)
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ lexicon = Lexicon(params.lang_dir)
+ max_phone_id = max(lexicon.tokens)
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+
+ graph_compiler = CtcTrainingGraphCompiler(lexicon=lexicon, device=device)
+
+ model = Tdnn(
+ num_features=params.feature_dim,
+ num_classes=max_phone_id + 1, # +1 for the blank symbol
+ )
+
+ checkpoints = load_checkpoint_if_available(params=params, model=model)
+
+ model.to(device)
+ if world_size > 1:
+ model = DDP(model, device_ids=[rank])
+
+ optimizer = optim.SGD(
+ model.parameters(),
+ lr=params.lr,
+ weight_decay=params.weight_decay,
+ )
+
+ if checkpoints:
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ yes_no = YesNoAsrDataModule(args)
+ train_dl = yes_no.train_dataloaders()
+
+ # There are only 60 waves: 30 files are used for training
+ # and the remaining 30 files are used for testing.
+ # We use test data as validation.
+ valid_dl = yes_no.test_dataloaders()
+
+ for epoch in range(params.start_epoch, params.num_epochs):
+ train_dl.sampler.set_epoch(epoch)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ optimizer=optimizer,
+ graph_compiler=graph_compiler,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ )
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ optimizer=optimizer,
+ scheduler=None,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def main():
+ parser = get_parser()
+ YesNoAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/icefall/bpe_graph_compiler.py b/icefall/bpe_graph_compiler.py
index c28de42bf..813b15f76 100644
--- a/icefall/bpe_graph_compiler.py
+++ b/icefall/bpe_graph_compiler.py
@@ -1,3 +1,20 @@
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
from pathlib import Path
from typing import List, Union
diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py
index e45df4fe4..b8a628f4e 100644
--- a/icefall/checkpoint.py
+++ b/icefall/checkpoint.py
@@ -1,3 +1,20 @@
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
@@ -91,7 +108,7 @@ def load_checkpoint(
checkpoint.pop("model")
def load(name, obj):
- s = checkpoint[name]
+ s = checkpoint.get(name, None)
if obj and s:
obj.load_state_dict(s)
checkpoint.pop(name)
diff --git a/icefall/dataset/datamodule.py b/icefall/dataset/datamodule.py
index 8560c5db0..97918ffd8 100644
--- a/icefall/dataset/datamodule.py
+++ b/icefall/dataset/datamodule.py
@@ -1,3 +1,20 @@
+# Copyright 2021 Piotr Żelasko
+#
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
import argparse
from typing import List, Union
diff --git a/icefall/dataset/librispeech.py b/icefall/dataset/librispeech.py
deleted file mode 100644
index 5c18041ed..000000000
--- a/icefall/dataset/librispeech.py
+++ /dev/null
@@ -1,68 +0,0 @@
-import argparse
-import logging
-from functools import lru_cache
-from typing import List
-
-from lhotse import CutSet, load_manifest
-
-from icefall.dataset.asr_datamodule import AsrDataModule
-from icefall.utils import str2bool
-
-
-class LibriSpeechAsrDataModule(AsrDataModule):
- """
- LibriSpeech ASR data module. Can be used for 100h subset
- (``--full-libri false``) or full 960h set.
- The train and valid cuts for standard Libri splits are
- concatenated into a single CutSet/DataLoader.
- """
-
- @classmethod
- def add_arguments(cls, parser: argparse.ArgumentParser):
- super().add_arguments(parser)
- group = parser.add_argument_group(title="LibriSpeech specific options")
- group.add_argument(
- "--full-libri",
- type=str2bool,
- default=True,
- help="When enabled, use 960h LibriSpeech.",
- )
-
- @lru_cache()
- def train_cuts(self) -> CutSet:
- logging.info("About to get train cuts")
- cuts_train = load_manifest(
- self.args.feature_dir / "cuts_train-clean-100.json.gz"
- )
- if self.args.full_libri:
- cuts_train = (
- cuts_train
- + load_manifest(
- self.args.feature_dir / "cuts_train-clean-360.json.gz"
- )
- + load_manifest(
- self.args.feature_dir / "cuts_train-other-500.json.gz"
- )
- )
- return cuts_train
-
- @lru_cache()
- def valid_cuts(self) -> CutSet:
- logging.info("About to get dev cuts")
- cuts_valid = load_manifest(
- self.args.feature_dir / "cuts_dev-clean.json.gz"
- ) + load_manifest(self.args.feature_dir / "cuts_dev-other.json.gz")
- return cuts_valid
-
- @lru_cache()
- def test_cuts(self) -> List[CutSet]:
- test_sets = ["test-clean", "test-other"]
- cuts = []
- for test_set in test_sets:
- logging.debug("About to get test cuts")
- cuts.append(
- load_manifest(
- self.args.feature_dir / f"cuts_{test_set}.json.gz"
- )
- )
- return cuts
diff --git a/icefall/decode.py b/icefall/decode.py
index 0e9baf2e4..e678e4622 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -1,9 +1,26 @@
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import logging
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Union
import k2
import torch
-import torch.nn as nn
+
+from icefall.utils import get_texts
def _intersect_device(
@@ -18,7 +35,7 @@ def _intersect_device(
CUDA OOM error.
The arguments and return value of this function are the same as
- k2.intersect_device.
+ :func:`k2.intersect_device`.
"""
num_fsas = b_fsas.shape[0]
if num_fsas <= batch_size:
@@ -37,8 +54,8 @@ def _intersect_device(
for start, end in splits:
indexes = torch.arange(start, end).to(b_to_a_map)
- fsas = k2.index(b_fsas, indexes)
- b_to_a = k2.index(b_to_a_map, indexes)
+ fsas = k2.index_fsa(b_fsas, indexes)
+ b_to_a = k2.index_select(b_to_a_map, indexes)
path_lattice = k2.intersect_device(
a_fsas, fsas, b_to_a_map=b_to_a, sorted_match_a=sorted_match_a
)
@@ -59,10 +76,9 @@ def get_lattice(
) -> k2.Fsa:
"""Get the decoding lattice from a decoding graph and neural
network output.
-
Args:
nnet_output:
- It is the output of a neural model of shape `[N, T, C]`.
+ It is the output of a neural model of shape `(N, T, C)`.
HLG:
An Fsa, the decoding graph. See also `compile_HLG.py`.
supervision_segments:
@@ -92,10 +108,12 @@ def get_lattice(
subsampling_factor:
The subsampling factor of the model.
Returns:
- A lattice containing the decoding result.
+ An FsaVec containing the decoding result. It has axes [utt][state][arc].
"""
dense_fsa_vec = k2.DenseFsaVec(
- nnet_output, supervision_segments, allow_truncate=subsampling_factor - 1
+ nnet_output,
+ supervision_segments,
+ allow_truncate=subsampling_factor - 1,
)
lattice = k2.intersect_dense_pruned(
@@ -110,8 +128,304 @@ def get_lattice(
return lattice
+class Nbest(object):
+ """
+ An Nbest object contains two fields:
+
+ (1) fsa. It is an FsaVec containing a vector of **linear** FSAs.
+ Its axes are [path][state][arc]
+ (2) shape. Its type is :class:`k2.RaggedShape`.
+ Its axes are [utt][path]
+
+ The field `shape` has two axes [utt][path]. `shape.dim0` contains
+ the number of utterances, which is also the number of rows in the
+ supervision_segments. `shape.tot_size(1)` contains the number
+ of paths, which is also the number of FSAs in `fsa`.
+
+ Caution:
+ Don't be confused by the name `Nbest`. The best in the name `Nbest`
+ has nothing to do with `best scores`. The important part is
+ `N` in `Nbest`, not `best`.
+ """
+
+ def __init__(self, fsa: k2.Fsa, shape: k2.RaggedShape) -> None:
+ """
+ Args:
+ fsa:
+ An FsaVec with axes [path][state][arc]. It is expected to contain
+ a list of **linear** FSAs.
+ shape:
+ A ragged shape with two axes [utt][path].
+ """
+ assert len(fsa.shape) == 3, f"fsa.shape: {fsa.shape}"
+ assert shape.num_axes == 2, f"num_axes: {shape.num_axes}"
+
+ if fsa.shape[0] != shape.tot_size(1):
+ raise ValueError(
+ f"{fsa.shape[0]} vs {shape.tot_size(1)}\n"
+ "Number of FSAs in `fsa` does not match the given shape"
+ )
+
+ self.fsa = fsa
+ self.shape = shape
+
+ def __str__(self):
+ s = "Nbest("
+ s += f"Number of utterances:{self.shape.dim0}, "
+ s += f"Number of Paths:{self.fsa.shape[0]})"
+ return s
+
+ @staticmethod
+ def from_lattice(
+ lattice: k2.Fsa,
+ num_paths: int,
+ use_double_scores: bool = True,
+ lattice_score_scale: float = 0.5,
+ ) -> "Nbest":
+ """Construct an Nbest object by **sampling** `num_paths` from a lattice.
+
+ Each sampled path is a linear FSA.
+
+ We assume `lattice.labels` contains token IDs and `lattice.aux_labels`
+ contains word IDs.
+
+ Args:
+ lattice:
+ An FsaVec with axes [utt][state][arc].
+ num_paths:
+ Number of paths to **sample** from the lattice
+ using :func:`k2.random_paths`.
+ use_double_scores:
+ True to use double precision in :func:`k2.random_paths`.
+ False to use single precision.
+ scale:
+ Scale `lattice.score` before passing it to :func:`k2.random_paths`.
+ A smaller value leads to more unique paths at the risk of being not
+ to sample the path with the best score.
+ Returns:
+ Return an Nbest instance.
+ """
+ saved_scores = lattice.scores.clone()
+ lattice.scores *= lattice_score_scale
+ # path is a ragged tensor with dtype torch.int32.
+ # It has three axes [utt][path][arc_pos]
+ path = k2.random_paths(
+ lattice, num_paths=num_paths, use_double_scores=use_double_scores
+ )
+ lattice.scores = saved_scores
+
+ # word_seq is a k2.RaggedTensor sharing the same shape as `path`
+ # but it contains word IDs. Note that it also contains 0s and -1s.
+ # The last entry in each sublist is -1.
+ # It axes is [utt][path][word_id]
+ if isinstance(lattice.aux_labels, torch.Tensor):
+ word_seq = k2.ragged.index(lattice.aux_labels, path)
+ else:
+ word_seq = lattice.aux_labels.index(path)
+ word_seq = word_seq.remove_axis(word_seq.num_axes - 2)
+
+ # Each utterance has `num_paths` paths but some of them transduces
+ # to the same word sequence, so we need to remove repeated word
+ # sequences within an utterance. After removing repeats, each utterance
+ # contains different number of paths
+ #
+ # `new2old` is a 1-D torch.Tensor mapping from the output path index
+ # to the input path index.
+ _, _, new2old = word_seq.unique(
+ need_num_repeats=False, need_new2old_indexes=True
+ )
+
+ # kept_path is a ragged tensor with dtype torch.int32.
+ # It has axes [utt][path][arc_pos]
+ kept_path, _ = path.index(new2old, axis=1, need_value_indexes=False)
+
+ # utt_to_path_shape has axes [utt][path]
+ utt_to_path_shape = kept_path.shape.get_layer(0)
+
+ # Remove the utterance axis.
+ # Now kept_path has only two axes [path][arc_pos]
+ kept_path = kept_path.remove_axis(0)
+
+ # labels is a ragged tensor with 2 axes [path][token_id]
+ # Note that it contains -1s.
+ labels = k2.ragged.index(lattice.labels.contiguous(), kept_path)
+
+ # Remove -1 from labels as we will use it to construct a linear FSA
+ labels = labels.remove_values_eq(-1)
+
+ if isinstance(lattice.aux_labels, k2.RaggedTensor):
+ # lattice.aux_labels is a ragged tensor with dtype torch.int32.
+ # It has 2 axes [arc][word], so aux_labels is also a ragged tensor
+ # with 2 axes [arc][word]
+ aux_labels, _ = lattice.aux_labels.index(
+ indexes=kept_path.values, axis=0, need_value_indexes=False
+ )
+ else:
+ assert isinstance(lattice.aux_labels, torch.Tensor)
+ aux_labels = k2.index_select(lattice.aux_labels, kept_path.values)
+ # aux_labels is a 1-D torch.Tensor. It also contains -1 and 0.
+
+ fsa = k2.linear_fsa(labels)
+ fsa.aux_labels = aux_labels
+ # Caution: fsa.scores are all 0s.
+ # `fsa` has only one extra attribute: aux_labels.
+ return Nbest(fsa=fsa, shape=utt_to_path_shape)
+
+ def intersect(self, lattice: k2.Fsa, use_double_scores=True) -> "Nbest":
+ """Intersect this Nbest object with a lattice, get 1-best
+ path from the resulting FsaVec, and return a new Nbest object.
+
+ The purpose of this function is to attach scores to an Nbest.
+
+ Args:
+ lattice:
+ An FsaVec with axes [utt][state][arc]. If it has `aux_labels`, then
+ we assume its `labels` are token IDs and `aux_labels` are word IDs.
+ If it has only `labels`, we assume its `labels` are word IDs.
+ use_double_scores:
+ True to use double precision when computing shortest path.
+ False to use single precision.
+ Returns:
+ Return a new Nbest. This new Nbest shares the same shape with `self`,
+ while its `fsa` is the 1-best path from intersecting `self.fsa` and
+ `lattice`. Also, its `fsa` has non-zero scores and inherits attributes
+ for `lattice`.
+ """
+ # Note: We view each linear FSA as a word sequence
+ # and we use the passed lattice to give each word sequence a score.
+ #
+ # We are not viewing each linear FSAs as a token sequence.
+ #
+ # So we use k2.invert() here.
+
+ # We use a word fsa to intersect with k2.invert(lattice)
+ word_fsa = k2.invert(self.fsa)
+
+ if hasattr(lattice, "aux_labels"):
+ # delete token IDs as it is not needed
+ del word_fsa.aux_labels
+
+ word_fsa.scores.zero_()
+ word_fsa_with_epsilon_loops = k2.remove_epsilon_and_add_self_loops(
+ word_fsa
+ )
+
+ path_to_utt_map = self.shape.row_ids(1)
+
+ if hasattr(lattice, "aux_labels"):
+ # lattice has token IDs as labels and word IDs as aux_labels.
+ # inv_lattice has word IDs as labels and token IDs as aux_labels
+ inv_lattice = k2.invert(lattice)
+ inv_lattice = k2.arc_sort(inv_lattice)
+ else:
+ inv_lattice = k2.arc_sort(lattice)
+
+ if inv_lattice.shape[0] == 1:
+ path_lattice = _intersect_device(
+ inv_lattice,
+ word_fsa_with_epsilon_loops,
+ b_to_a_map=torch.zeros_like(path_to_utt_map),
+ sorted_match_a=True,
+ )
+ else:
+ path_lattice = _intersect_device(
+ inv_lattice,
+ word_fsa_with_epsilon_loops,
+ b_to_a_map=path_to_utt_map,
+ sorted_match_a=True,
+ )
+
+ # path_lattice has word IDs as labels and token IDs as aux_labels
+ path_lattice = k2.top_sort(k2.connect(path_lattice))
+
+ one_best = k2.shortest_path(
+ path_lattice, use_double_scores=use_double_scores
+ )
+
+ one_best = k2.invert(one_best)
+ # Now one_best has token IDs as labels and word IDs as aux_labels
+
+ return Nbest(fsa=one_best, shape=self.shape)
+
+ def compute_am_scores(self) -> k2.RaggedTensor:
+ """Compute AM scores of each linear FSA (i.e., each path within
+ an utterance).
+
+ Hint:
+ `self.fsa.scores` contains two parts: acoustic scores (AM scores)
+ and n-gram language model scores (LM scores).
+
+ Caution:
+ We require that ``self.fsa`` has an attribute ``lm_scores``.
+
+ Returns:
+ Return a ragged tensor with 2 axes [utt][path_scores].
+ Its dtype is torch.float64.
+ """
+ saved_scores = self.fsa.scores
+
+ # The `scores` of every arc consists of `am_scores` and `lm_scores`
+ self.fsa.scores = self.fsa.scores - self.fsa.lm_scores
+
+ am_scores = self.fsa.get_tot_scores(
+ use_double_scores=True, log_semiring=False
+ )
+ self.fsa.scores = saved_scores
+
+ return k2.RaggedTensor(self.shape, am_scores)
+
+ def compute_lm_scores(self) -> k2.RaggedTensor:
+ """Compute LM scores of each linear FSA (i.e., each path within
+ an utterance).
+
+ Hint:
+ `self.fsa.scores` contains two parts: acoustic scores (AM scores)
+ and n-gram language model scores (LM scores).
+
+ Caution:
+ We require that ``self.fsa`` has an attribute ``lm_scores``.
+
+ Returns:
+ Return a ragged tensor with 2 axes [utt][path_scores].
+ Its dtype is torch.float64.
+ """
+ saved_scores = self.fsa.scores
+
+ # The `scores` of every arc consists of `am_scores` and `lm_scores`
+ self.fsa.scores = self.fsa.lm_scores
+
+ lm_scores = self.fsa.get_tot_scores(
+ use_double_scores=True, log_semiring=False
+ )
+ self.fsa.scores = saved_scores
+
+ return k2.RaggedTensor(self.shape, lm_scores)
+
+ def tot_scores(self) -> k2.RaggedTensor:
+ """Get total scores of FSAs in this Nbest.
+
+ Note:
+ Since FSAs in Nbest are just linear FSAs, log-semiring
+ and tropical semiring produce the same total scores.
+
+ Returns:
+ Return a ragged tensor with two axes [utt][path_scores].
+ Its dtype is torch.float64.
+ """
+ scores = self.fsa.get_tot_scores(
+ use_double_scores=True, log_semiring=False
+ )
+ return k2.RaggedTensor(self.shape, scores)
+
+ def build_levenshtein_graphs(self) -> k2.Fsa:
+ """Return an FsaVec with axes [utt][state][arc]."""
+ word_ids = get_texts(self.fsa, return_ragged=True)
+ return k2.levenshtein_graph(word_ids)
+
+
def one_best_decoding(
- lattice: k2.Fsa, use_double_scores: bool = True
+ lattice: k2.Fsa,
+ use_double_scores: bool = True,
) -> k2.Fsa:
"""Get the best path from a lattice.
@@ -129,222 +443,179 @@ def one_best_decoding(
def nbest_decoding(
- lattice: k2.Fsa, num_paths: int, use_double_scores: bool = True
+ lattice: k2.Fsa,
+ num_paths: int,
+ use_double_scores: bool = True,
+ lattice_score_scale: float = 1.0,
) -> k2.Fsa:
"""It implements something like CTC prefix beam search using n-best lists.
- The basic idea is to first extra n-best paths from the given lattice,
- build a word seqs from these paths, and compute the total scores
- of these sequences in the log-semiring. The one with the max score
+ The basic idea is to first extract `num_paths` paths from the given lattice,
+ build a word sequence from these paths, and compute the total scores
+ of the word sequence in the tropical semiring. The one with the max score
is used as the decoding output.
Caution:
Don't be confused by `best` in the name `n-best`. Paths are selected
- randomly, not by ranking their scores.
+ **randomly**, not by ranking their scores.
+
+ Hint:
+ This decoding method is for demonstration only and it does
+ not produce a lower WER than :func:`one_best_decoding`.
Args:
lattice:
- The decoding lattice, returned by :func:`get_lattice`.
+ The decoding lattice, e.g., can be the return value of
+ :func:`get_lattice`. It has 3 axes [utt][state][arc].
num_paths:
It specifies the size `n` in n-best. Note: Paths are selected randomly
- and those containing identical word sequences are remove dand only one
+ and those containing identical word sequences are removed and only one
of them is kept.
use_double_scores:
True to use double precision floating point in the computation.
False to use single precision.
+ lattice_score_scale:
+ It's the scale applied to the `lattice.scores`. A smaller value
+ leads to more unique paths at the risk of missing the correct path.
Returns:
- An FsaVec containing linear FSAs.
+ An FsaVec containing **linear** FSAs. It axes are [utt][state][arc].
"""
- # First, extract `num_paths` paths for each sequence.
- # path is a k2.RaggedInt with axes [seq][path][arc_pos]
- path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)
-
- # word_seq is a k2.RaggedInt sharing the same shape as `path`
- # but it contains word IDs. Note that it also contains 0s and -1s.
- # The last entry in each sublist is -1.
- word_seq = k2.index(lattice.aux_labels, path)
- # Note: the above operation supports also the case when
- # lattice.aux_labels is a ragged tensor. In that case,
- # `remove_axis=True` is used inside the pybind11 binding code,
- # so the resulting `word_seq` still has 3 axes, like `path`.
- # The 3 axes are [seq][path][word_id]
-
- # Remove 0 (epsilon) and -1 from word_seq
- word_seq = k2.ragged.remove_values_leq(word_seq, 0)
-
- # Remove sequences with identical word sequences.
- #
- # k2.ragged.unique_sequences will reorder paths within a seq.
- # `new2old` is a 1-D torch.Tensor mapping from the output path index
- # to the input path index.
- # new2old.numel() == unique_word_seqs.tot_size(1)
- unique_word_seq, _, new2old = k2.ragged.unique_sequences(
- word_seq, need_num_repeats=False, need_new2old_indexes=True
+ nbest = Nbest.from_lattice(
+ lattice=lattice,
+ num_paths=num_paths,
+ use_double_scores=use_double_scores,
+ lattice_score_scale=lattice_score_scale,
)
- # Note: unique_word_seq still has the same axes as word_seq
+ # nbest.fsa.scores contains 0s
- seq_to_path_shape = k2.ragged.get_layer(unique_word_seq.shape(), 0)
+ nbest = nbest.intersect(lattice)
+ # now nbest.fsa.scores gets assigned
- # path_to_seq_map is a 1-D torch.Tensor.
- # path_to_seq_map[i] is the seq to which the i-th path belongs
- path_to_seq_map = seq_to_path_shape.row_ids(1)
+ # max_indexes contains the indexes for the path with the maximum score
+ # within an utterance.
+ max_indexes = nbest.tot_scores().argmax()
- # Remove the seq axis.
- # Now unique_word_seq has only two axes [path][word]
- unique_word_seq = k2.ragged.remove_axis(unique_word_seq, 0)
-
- # word_fsa is an FsaVec with axes [path][state][arc]
- word_fsa = k2.linear_fsa(unique_word_seq)
-
- # add epsilon self loops since we will use
- # k2.intersect_device, which treats epsilon as a normal symbol
- word_fsa_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsa)
-
- # lattice has token IDs as labels and word IDs as aux_labels.
- # inv_lattice has word IDs as labels and token IDs as aux_labels
- inv_lattice = k2.invert(lattice)
- inv_lattice = k2.arc_sort(inv_lattice)
-
- path_lattice = _intersect_device(
- inv_lattice,
- word_fsa_with_epsilon_loops,
- b_to_a_map=path_to_seq_map,
- sorted_match_a=True,
- )
- # path_lat has word IDs as labels and token IDs as aux_labels
-
- path_lattice = k2.top_sort(k2.connect(path_lattice))
-
- tot_scores = path_lattice.get_tot_scores(
- use_double_scores=use_double_scores, log_semiring=False
- )
-
- # RaggedFloat currently supports float32 only.
- # If Ragged is wrapped, we can use k2.RaggedDouble here
- ragged_tot_scores = k2.RaggedFloat(
- seq_to_path_shape, tot_scores.to(torch.float32)
- )
-
- argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)
-
- # Since we invoked `k2.ragged.unique_sequences`, which reorders
- # the index from `path`, we use `new2old` here to convert argmax_indexes
- # to the indexes into `path`.
- #
- # Use k2.index here since argmax_indexes' dtype is torch.int32
- best_path_indexes = k2.index(new2old, argmax_indexes)
-
- path_2axes = k2.ragged.remove_axis(path, 0)
-
- # best_path is a k2.RaggedInt with 2 axes [path][arc_pos]
- best_path = k2.index(path_2axes, best_path_indexes)
-
- # labels is a k2.RaggedInt with 2 axes [path][token_id]
- # Note that it contains -1s.
- labels = k2.index(lattice.labels.contiguous(), best_path)
-
- labels = k2.ragged.remove_values_eq(labels, -1)
-
- # lattice.aux_labels is a k2.RaggedInt tensor with 2 axes, so
- # aux_labels is also a k2.RaggedInt with 2 axes
- aux_labels = k2.index(lattice.aux_labels, best_path.values())
-
- best_path_fsa = k2.linear_fsa(labels)
- best_path_fsa.aux_labels = aux_labels
- return best_path_fsa
+ best_path = k2.index_fsa(nbest.fsa, max_indexes)
+ return best_path
-def compute_am_and_lm_scores(
+def nbest_oracle(
lattice: k2.Fsa,
- word_fsa_with_epsilon_loops: k2.Fsa,
- path_to_seq_map: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
- """Compute AM scores of n-best lists (represented as word_fsas).
+ num_paths: int,
+ ref_texts: List[str],
+ word_table: k2.SymbolTable,
+ use_double_scores: bool = True,
+ lattice_score_scale: float = 0.5,
+ oov: str = "",
+) -> Dict[str, List[List[int]]]:
+ """Select the best hypothesis given a lattice and a reference transcript.
+
+ The basic idea is to extract `num_paths` paths from the given lattice,
+ unique them, and select the one that has the minimum edit distance with
+ the corresponding reference transcript as the decoding output.
+
+ The decoding result returned from this function is the best result that
+ we can obtain using n-best decoding with all kinds of rescoring techniques.
+
+ This function is useful to tune the value of `lattice_score_scale`.
Args:
lattice:
- An FsaVec, e.g., the return value of :func:`get_lattice`
- It must have the attribute `lm_scores`.
- word_fsa_with_epsilon_loops:
- An FsaVec representing an n-best list. Note that it has been processed
- by `k2.add_epsilon_self_loops`.
- path_to_seq_map:
- A 1-D torch.Tensor with dtype torch.int32. path_to_seq_map[i] indicates
- which sequence the i-th Fsa in word_fsa_with_epsilon_loops belongs to.
- path_to_seq_map.numel() == word_fsas_with_epsilon_loops.arcs.dim0().
- Returns:
- Return a tuple containing two 1-D torch.Tensors: (am_scores, lm_scores).
- Each tensor's `numel()' equals to `word_fsas_with_epsilon_loops.shape[0]`
+ An FsaVec with axes [utt][state][arc].
+ Note: We assume its `aux_labels` contains word IDs.
+ num_paths:
+ The size of `n` in n-best.
+ ref_texts:
+ A list of reference transcript. Each entry contains space(s)
+ separated words
+ word_table:
+ It is the word symbol table.
+ use_double_scores:
+ True to use double precision for computation. False to use
+ single precision.
+ lattice_score_scale:
+ It's the scale applied to the lattice.scores. A smaller value
+ yields more unique paths.
+ oov:
+ The out of vocabulary word.
+ Return:
+ Return a dict. Its key contains the information about the parameters
+ when calling this function, while its value contains the decoding output.
+ `len(ans_dict) == len(ref_texts)`
"""
- assert len(lattice.shape) == 3
- assert hasattr(lattice, "lm_scores")
+ device = lattice.device
- # k2.compose() currently does not support b_to_a_map. To void
- # replicating `lats`, we use k2.intersect_device here.
- #
- # lattice has token IDs as `labels` and word IDs as aux_labels, so we
- # need to invert it here.
- inv_lattice = k2.invert(lattice)
-
- # Now the `labels` of inv_lattice are word IDs (a 1-D torch.Tensor)
- # and its `aux_labels` are token IDs ( a k2.RaggedInt with 2 axes)
-
- # Remove its `aux_labels` since it is not needed in the
- # following computation
- del inv_lattice.aux_labels
- inv_lattice = k2.arc_sort(inv_lattice)
-
- path_lattice = _intersect_device(
- inv_lattice,
- word_fsa_with_epsilon_loops,
- b_to_a_map=path_to_seq_map,
- sorted_match_a=True,
+ nbest = Nbest.from_lattice(
+ lattice=lattice,
+ num_paths=num_paths,
+ use_double_scores=use_double_scores,
+ lattice_score_scale=lattice_score_scale,
)
- path_lattice = k2.top_sort(k2.connect(path_lattice))
+ hyps = nbest.build_levenshtein_graphs()
- # The `scores` of every arc consists of `am_scores` and `lm_scores`
- path_lattice.scores = path_lattice.scores - path_lattice.lm_scores
+ oov_id = word_table[oov]
+ word_ids_list = []
+ for text in ref_texts:
+ word_ids = []
+ for word in text.split():
+ if word in word_table:
+ word_ids.append(word_table[word])
+ else:
+ word_ids.append(oov_id)
+ word_ids_list.append(word_ids)
- am_scores = path_lattice.get_tot_scores(
- use_double_scores=True, log_semiring=False
+ refs = k2.levenshtein_graph(word_ids_list, device=device)
+
+ levenshtein_alignment = k2.levenshtein_alignment(
+ refs=refs,
+ hyps=hyps,
+ hyp_to_ref_map=nbest.shape.row_ids(1),
+ sorted_match_ref=True,
)
- path_lattice.scores = path_lattice.lm_scores
-
- lm_scores = path_lattice.get_tot_scores(
- use_double_scores=True, log_semiring=False
+ tot_scores = levenshtein_alignment.get_tot_scores(
+ use_double_scores=False, log_semiring=False
)
+ ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
- return am_scores.to(torch.float32), lm_scores.to(torch.float32)
+ max_indexes = ragged_tot_scores.argmax()
+
+ best_path = k2.index_fsa(nbest.fsa, max_indexes)
+ return best_path
def rescore_with_n_best_list(
- lattice: k2.Fsa, G: k2.Fsa, num_paths: int, lm_scale_list: List[float]
+ lattice: k2.Fsa,
+ G: k2.Fsa,
+ num_paths: int,
+ lm_scale_list: List[float],
+ lattice_score_scale: float = 1.0,
+ use_double_scores: bool = True,
) -> Dict[str, k2.Fsa]:
- """Decode using n-best list with LM rescoring.
-
- `lattice` is a decoding lattice with 3 axes. This function first
- extracts `num_paths` paths from `lattice` for each sequence using
- `k2.random_paths`. The `am_scores` of these paths are computed.
- For each path, its `lm_scores` is computed using `G` (which is an LM).
- The final `tot_scores` is the sum of `am_scores` and `lm_scores`.
- The path with the largest `tot_scores` within a sequence is used
- as the decoding output.
+ """Rescore an n-best list with an n-gram LM.
+ The path with the maximum score is used as the decoding output.
Args:
lattice:
- An FsaVec. It can be the return value of :func:`get_lattice`.
+ An FsaVec with axes [utt][state][arc]. It must have the following
+ attributes: ``aux_labels`` and ``lm_scores``. Its labels are
+ token IDs and ``aux_labels`` word IDs.
G:
- An FsaVec representing the language model (LM). Note that it
- is an FsaVec, but it contains only one Fsa.
+ An FsaVec containing only a single FSA. It is an n-gram LM.
num_paths:
- It is the size `n` in `n-best` list.
+ Size of nbest list.
lm_scale_list:
- A list containing lm_scale values.
+ A list of float representing LM score scales.
+ lattice_score_scale:
+ Scale to be applied to ``lattice.score`` when sampling paths
+ using ``k2.random_paths``.
+ use_double_scores:
+ True to use double precision during computation. False to use
+ single precision.
Returns:
A dict of FsaVec, whose key is an lm_scale and the value is the
- best decoding path for each sequence in the lattice.
+ best decoding path for each utterance in the lattice.
"""
device = lattice.device
@@ -356,109 +627,32 @@ def rescore_with_n_best_list(
assert G.device == device
assert hasattr(G, "aux_labels") is False
- # First, extract `num_paths` paths for each sequence.
- # path is a k2.RaggedInt with axes [seq][path][arc_pos]
- path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)
-
- # word_seq is a k2.RaggedInt sharing the same shape as `path`
- # but it contains word IDs. Note that it also contains 0s and -1s.
- # The last entry in each sublist is -1.
- word_seq = k2.index(lattice.aux_labels, path)
-
- # Remove epsilons and -1 from word_seq
- word_seq = k2.ragged.remove_values_leq(word_seq, 0)
-
- # Remove paths that has identical word sequences.
- #
- # unique_word_seq is still a k2.RaggedInt with 3 axes [seq][path][word]
- # except that there are no repeated paths with the same word_seq
- # within a sequence.
- #
- # num_repeats is also a k2.RaggedInt with 2 axes containing the
- # multiplicities of each path.
- # num_repeats.num_elements() == unique_word_seqs.num_elements()
- #
- # Since k2.ragged.unique_sequences will reorder paths within a seq,
- # `new2old` is a 1-D torch.Tensor mapping from the output path index
- # to the input path index.
- # new2old.numel() == unique_word_seqs.tot_size(1)
- unique_word_seq, num_repeats, new2old = k2.ragged.unique_sequences(
- word_seq, need_num_repeats=True, need_new2old_indexes=True
+ nbest = Nbest.from_lattice(
+ lattice=lattice,
+ num_paths=num_paths,
+ use_double_scores=use_double_scores,
+ lattice_score_scale=lattice_score_scale,
)
+ # nbest.fsa.scores are all 0s at this point
- seq_to_path_shape = k2.ragged.get_layer(unique_word_seq.shape(), 0)
+ nbest = nbest.intersect(lattice)
+ # Now nbest.fsa has its scores set
+ assert hasattr(nbest.fsa, "lm_scores")
- # path_to_seq_map is a 1-D torch.Tensor.
- # path_to_seq_map[i] is the seq to which the i-th path
- # belongs.
- path_to_seq_map = seq_to_path_shape.row_ids(1)
+ am_scores = nbest.compute_am_scores()
- # Remove the seq axis.
- # Now unique_word_seq has only two axes [path][word]
- unique_word_seq = k2.ragged.remove_axis(unique_word_seq, 0)
-
- # word_fsa is an FsaVec with axes [path][state][arc]
- word_fsa = k2.linear_fsa(unique_word_seq)
-
- word_fsa_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsa)
-
- am_scores, _ = compute_am_and_lm_scores(
- lattice, word_fsa_with_epsilon_loops, path_to_seq_map
- )
-
- # Now compute lm_scores
- b_to_a_map = torch.zeros_like(path_to_seq_map)
- lm_path_lattice = _intersect_device(
- G,
- word_fsa_with_epsilon_loops,
- b_to_a_map=b_to_a_map,
- sorted_match_a=True,
- )
- lm_path_lattice = k2.top_sort(k2.connect(lm_path_lattice))
- lm_scores = lm_path_lattice.get_tot_scores(
- use_double_scores=True, log_semiring=False
- )
-
- path_2axes = k2.ragged.remove_axis(path, 0)
+ nbest = nbest.intersect(G)
+ # Now nbest contains only lm scores
+ lm_scores = nbest.tot_scores()
ans = dict()
for lm_scale in lm_scale_list:
- tot_scores = am_scores / lm_scale + lm_scores
-
- # Remember that we used `k2.ragged.unique_sequences` to remove repeated
- # paths to avoid redundant computation in `k2.intersect_device`.
- # Now we use `num_repeats` to correct the scores for each path.
- #
- # NOTE(fangjun): It is commented out as it leads to a worse WER
- # tot_scores = tot_scores * num_repeats.values()
-
- ragged_tot_scores = k2.RaggedFloat(
- seq_to_path_shape, tot_scores.to(torch.float32)
- )
- argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)
-
- # Use k2.index here since argmax_indexes' dtype is torch.int32
- best_path_indexes = k2.index(new2old, argmax_indexes)
-
- # best_path is a k2.RaggedInt with 2 axes [path][arc_pos]
- best_path = k2.index(path_2axes, best_path_indexes)
-
- # labels is a k2.RaggedInt with 2 axes [path][phone_id]
- # Note that it contains -1s.
- labels = k2.index(lattice.labels.contiguous(), best_path)
-
- labels = k2.ragged.remove_values_eq(labels, -1)
-
- # lattice.aux_labels is a k2.RaggedInt tensor with 2 axes, so
- # aux_labels is also a k2.RaggedInt with 2 axes
- aux_labels = k2.index(lattice.aux_labels, best_path.values())
-
- best_path_fsa = k2.linear_fsa(labels)
- best_path_fsa.aux_labels = aux_labels
-
+ tot_scores = am_scores.values / lm_scale + lm_scores.values
+ tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
+ max_indexes = tot_scores.argmax()
+ best_path = k2.index_fsa(nbest.fsa, max_indexes)
key = f"lm_scale_{lm_scale}"
- ans[key] = best_path_fsa
-
+ ans[key] = best_path
return ans
@@ -466,25 +660,40 @@ def rescore_with_whole_lattice(
lattice: k2.Fsa,
G_with_epsilon_loops: k2.Fsa,
lm_scale_list: Optional[List[float]] = None,
+ use_double_scores: bool = True,
) -> Union[k2.Fsa, Dict[str, k2.Fsa]]:
- """Use whole lattice to rescore.
+ """Intersect the lattice with an n-gram LM and use shortest path
+ to decode.
+
+ The input lattice is obtained by intersecting `HLG` with
+ a DenseFsaVec, where the `G` in `HLG` is in general a 3-gram LM.
+ The input `G_with_epsilon_loops` is usually a 4-gram LM. You can consider
+ this function as a second pass decoding. In the first pass decoding, we
+ use a small G, while we use a larger G in the second pass decoding.
Args:
lattice:
- An FsaVec It can be the return value of :func:`get_lattice`.
+ An FsaVec with axes [utt][state][arc]. Its `aux_lables` are word IDs.
+ It must have an attribute `lm_scores`.
G_with_epsilon_loops:
- An FsaVec representing the language model (LM). Note that it
- is an FsaVec, but it contains only one Fsa.
+ An FsaVec containing only a single FSA. It contains epsilon self-loops.
+ It is an acceptor and its labels are word IDs.
lm_scale_list:
- A list containing lm_scale values or None.
+ Optional. If none, return the intersection of `lattice` and
+ `G_with_epsilon_loops`.
+ If not None, it contains a list of values to scale LM scores.
+ For each scale, there is a corresponding decoding result contained in
+ the resulting dict.
+ use_double_scores:
+ True to use double precision in the computation.
+ False to use single precision.
Returns:
- If lm_scale_list is not None, return a dict of FsaVec, whose key
- is a lm_scale and the value represents the best decoding path for
- each sequence in the lattice.
- If lm_scale_list is not None, return a lattice that is rescored
- with the given LM.
+ If `lm_scale_list` is None, return a new lattice which is the intersection
+ result of `lattice` and `G_with_epsilon_loops`.
+ Otherwise, return a dict whose key is an entry in `lm_scale_list` and the
+ value is the decoding result (i.e., an FsaVec containing linear FSAs).
"""
- assert len(lattice.shape) == 3
+ # Nbest is not used in this function
assert hasattr(lattice, "lm_scores")
assert G_with_epsilon_loops.shape == (1, None, None)
@@ -492,17 +701,22 @@ def rescore_with_whole_lattice(
lattice.scores = lattice.scores - lattice.lm_scores
# We will use lm_scores from G, so remove lats.lm_scores here
del lattice.lm_scores
- assert hasattr(lattice, "lm_scores") is False
+
+ assert hasattr(G_with_epsilon_loops, "lm_scores")
# Now, lattice.scores contains only am_scores
# inv_lattice has word IDs as labels.
- # Its aux_labels are token IDs, which is a ragged tensor k2.RaggedInt
+ # Its `aux_labels` is token IDs
inv_lattice = k2.invert(lattice)
num_seqs = lattice.shape[0]
b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)
- while True:
+
+ max_loop_count = 10
+ loop_count = 0
+ while loop_count <= max_loop_count:
+ loop_count += 1
try:
rescoring_lattice = k2.intersect_device(
G_with_epsilon_loops,
@@ -518,12 +732,15 @@ def rescore_with_whole_lattice(
f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}"
)
- # NOTE(fangjun): The choice of the threshold 1e-7 is arbitrary here
- # to avoid OOM. We may need to fine tune it.
- inv_lattice = k2.prune_on_arc_post(inv_lattice, 1e-7, True)
+ # NOTE(fangjun): The choice of the threshold 1e-9 is arbitrary here
+ # to avoid OOM. You may need to fine tune it.
+ inv_lattice = k2.prune_on_arc_post(inv_lattice, 1e-9, True)
logging.info(
f"num_arcs after pruning: {inv_lattice.arcs.num_elements()}"
)
+ if loop_count > max_loop_count:
+ logging.info("Return None as the resulting lattice is too large")
+ return None
# lat has token IDs as labels
# and word IDs as aux_labels.
@@ -533,17 +750,12 @@ def rescore_with_whole_lattice(
return lat
ans = dict()
- #
- # The following implements
- # scores = (scores - lm_scores)/lm_scale + lm_scores
- # = scores/lm_scale + lm_scores*(1 - 1/lm_scale)
- #
saved_am_scores = lat.scores - lat.lm_scores
for lm_scale in lm_scale_list:
am_scores = saved_am_scores / lm_scale
lat.scores = am_scores + lat.lm_scores
- best_path = k2.shortest_path(lat, use_double_scores=True)
+ best_path = k2.shortest_path(lat, use_double_scores=use_double_scores)
key = f"lm_scale_{lm_scale}"
ans[key] = best_path
return ans
@@ -552,19 +764,23 @@ def rescore_with_whole_lattice(
def rescore_with_attention_decoder(
lattice: k2.Fsa,
num_paths: int,
- model: nn.Module,
+ model: torch.nn.Module,
memory: torch.Tensor,
- memory_key_padding_mask: torch.Tensor,
+ memory_key_padding_mask: Optional[torch.Tensor],
sos_id: int,
eos_id: int,
+ lattice_score_scale: float = 1.0,
+ ngram_lm_scale: Optional[float] = None,
+ attention_scale: Optional[float] = None,
+ use_double_scores: bool = True,
) -> Dict[str, k2.Fsa]:
- """This function extracts n paths from the given lattice and uses
- an attention decoder to rescore them. The path with the highest
- score is used as the decoding output.
+ """This function extracts `num_paths` paths from the given lattice and uses
+ an attention decoder to rescore them. The path with the highest score is
+ the decoding output.
Args:
lattice:
- An FsaVec. It can be the return value of :func:`get_lattice`.
+ An FsaVec with axes [utt][state][arc].
num_paths:
Number of paths to extract from the given lattice for rescoring.
model:
@@ -573,100 +789,64 @@ def rescore_with_attention_decoder(
memory:
The encoder memory of the given model. It is the output of
the last torch.nn.TransformerEncoder layer in the given model.
- Its shape is `[T, N, C]`.
+ Its shape is `(T, N, C)`.
memory_key_padding_mask:
- The padding mask for memory with shape [N, T].
+ The padding mask for memory with shape `(N, T)`.
sos_id:
The token ID for SOS.
eos_id:
The token ID for EOS.
+ lattice_score_scale:
+ It's the scale applied to `lattice.scores`. A smaller value
+ leads to more unique paths at the risk of missing the correct path.
+ ngram_lm_scale:
+ Optional. It specifies the scale for n-gram LM scores.
+ attention_scale:
+ Optional. It specifies the scale for attention decoder scores.
Returns:
A dict of FsaVec, whose key contains a string
ngram_lm_scale_attention_scale and the value is the
- best decoding path for each sequence in the lattice.
+ best decoding path for each utterance in the lattice.
"""
- # First, extract `num_paths` paths for each sequence.
- # path is a k2.RaggedInt with axes [seq][path][arc_pos]
- path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)
-
- # word_seq is a k2.RaggedInt sharing the same shape as `path`
- # but it contains word IDs. Note that it also contains 0s and -1s.
- # The last entry in each sublist is -1.
- word_seq = k2.index(lattice.aux_labels, path)
-
- # Remove epsilons and -1 from word_seq
- word_seq = k2.ragged.remove_values_leq(word_seq, 0)
-
- # Remove paths that has identical word sequences.
- #
- # unique_word_seq is still a k2.RaggedInt with 3 axes [seq][path][word]
- # except that there are no repeated paths with the same word_seq
- # within a sequence.
- #
- # num_repeats is also a k2.RaggedInt with 2 axes containing the
- # multiplicities of each path.
- # num_repeats.num_elements() == unique_word_seqs.num_elements()
- #
- # Since k2.ragged.unique_sequences will reorder paths within a seq,
- # `new2old` is a 1-D torch.Tensor mapping from the output path index
- # to the input path index.
- # new2old.numel() == unique_word_seqs.tot_size(1)
- unique_word_seq, num_repeats, new2old = k2.ragged.unique_sequences(
- word_seq, need_num_repeats=True, need_new2old_indexes=True
+ nbest = Nbest.from_lattice(
+ lattice=lattice,
+ num_paths=num_paths,
+ use_double_scores=use_double_scores,
+ lattice_score_scale=lattice_score_scale,
)
+ # nbest.fsa.scores are all 0s at this point
- seq_to_path_shape = k2.ragged.get_layer(unique_word_seq.shape(), 0)
+ nbest = nbest.intersect(lattice)
+ # Now nbest.fsa has its scores set.
+ # Also, nbest.fsa inherits the attributes from `lattice`.
+ assert hasattr(nbest.fsa, "lm_scores")
- # path_to_seq_map is a 1-D torch.Tensor.
- # path_to_seq_map[i] is the seq to which the i-th path
- # belongs.
- path_to_seq_map = seq_to_path_shape.row_ids(1)
+ am_scores = nbest.compute_am_scores()
+ ngram_lm_scores = nbest.compute_lm_scores()
- # Remove the seq axis.
- # Now unique_word_seq has only two axes [path][word]
- unique_word_seq = k2.ragged.remove_axis(unique_word_seq, 0)
+ # The `tokens` attribute is set inside `compile_hlg.py`
+ assert hasattr(nbest.fsa, "tokens")
+ assert isinstance(nbest.fsa.tokens, torch.Tensor)
- # word_fsa is an FsaVec with axes [path][state][arc]
- word_fsa = k2.linear_fsa(unique_word_seq)
+ path_to_utt_map = nbest.shape.row_ids(1).to(torch.long)
+ # the shape of memory is (T, N, C), so we use axis=1 here
+ expanded_memory = memory.index_select(1, path_to_utt_map)
- word_fsa_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsa)
+ if memory_key_padding_mask is not None:
+ # The shape of memory_key_padding_mask is (N, T), so we
+ # use axis=0 here.
+ expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
+ 0, path_to_utt_map
+ )
+ else:
+ expanded_memory_key_padding_mask = None
- am_scores, ngram_lm_scores = compute_am_and_lm_scores(
- lattice, word_fsa_with_epsilon_loops, path_to_seq_map
- )
- # Now we use the attention decoder to compute another
- # score: attention_scores.
- #
- # To do that, we have to get the input and output for the attention
- # decoder.
+ # remove axis corresponding to states.
+ tokens_shape = nbest.fsa.arcs.shape().remove_axis(1)
+ tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens)
+ tokens = tokens.remove_values_leq(0)
+ token_ids = tokens.tolist()
- # CAUTION: The "tokens" attribute is set in the file
- # local/compile_hlg.py
- token_seq = k2.index(lattice.tokens, path)
-
- # Remove epsilons and -1 from token_seq
- token_seq = k2.ragged.remove_values_leq(token_seq, 0)
-
- # Remove the seq axis.
- token_seq = k2.ragged.remove_axis(token_seq, 0)
-
- token_seq, _ = k2.ragged.index(
- token_seq, indexes=new2old, axis=0, need_value_indexes=False
- )
-
- # Now word in unique_word_seq has its corresponding token IDs.
- token_ids = k2.ragged.to_list(token_seq)
-
- num_word_seqs = new2old.numel()
-
- path_to_seq_map_long = path_to_seq_map.to(torch.long)
- expanded_memory = memory.index_select(1, path_to_seq_map_long)
-
- expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
- 0, path_to_seq_map_long
- )
-
- # TODO: pass the sos_token_id and eos_token_id via function arguments
nll = model.decoder_nll(
memory=expanded_memory,
memory_key_padding_mask=expanded_memory_key_padding_mask,
@@ -675,49 +855,36 @@ def rescore_with_attention_decoder(
eos_id=eos_id,
)
assert nll.ndim == 2
- assert nll.shape[0] == num_word_seqs
+ assert nll.shape[0] == len(token_ids)
attention_scores = -nll.sum(dim=1)
- assert attention_scores.ndim == 1
- assert attention_scores.numel() == num_word_seqs
- ngram_lm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
- ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+ if ngram_lm_scale is None:
+ ngram_lm_scale_list = [0.01, 0.05, 0.08]
+ ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
+ ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+ else:
+ ngram_lm_scale_list = [ngram_lm_scale]
- attention_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
- attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
-
- path_2axes = k2.ragged.remove_axis(path, 0)
+ if attention_scale is None:
+ attention_scale_list = [0.01, 0.05, 0.08]
+ attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
+ attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+ else:
+ attention_scale_list = [attention_scale]
ans = dict()
for n_scale in ngram_lm_scale_list:
for a_scale in attention_scale_list:
tot_scores = (
- am_scores
- + n_scale * ngram_lm_scores
+ am_scores.values
+ + n_scale * ngram_lm_scores.values
+ a_scale * attention_scores
)
- ragged_tot_scores = k2.RaggedFloat(seq_to_path_shape, tot_scores)
- argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)
-
- best_path_indexes = k2.index(new2old, argmax_indexes)
-
- # best_path is a k2.RaggedInt with 2 axes [path][arc_pos]
- best_path = k2.index(path_2axes, best_path_indexes)
-
- # labels is a k2.RaggedInt with 2 axes [path][token_id]
- # Note that it contains -1s.
- labels = k2.index(lattice.labels.contiguous(), best_path)
-
- labels = k2.ragged.remove_values_eq(labels, -1)
-
- # lattice.aux_labels is a k2.RaggedInt tensor with 2 axes, so
- # aux_labels is also a k2.RaggedInt with 2 axes
- aux_labels = k2.index(lattice.aux_labels, best_path.values())
-
- best_path_fsa = k2.linear_fsa(labels)
- best_path_fsa.aux_labels = aux_labels
+ ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
+ max_indexes = ragged_tot_scores.argmax()
+ best_path = k2.index_fsa(nbest.fsa, max_indexes)
key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}"
- ans[key] = best_path_fsa
+ ans[key] = best_path
return ans
diff --git a/icefall/dist.py b/icefall/dist.py
index d314d2a43..203c7c563 100644
--- a/icefall/dist.py
+++ b/icefall/dist.py
@@ -1,3 +1,20 @@
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
import os
import torch
diff --git a/icefall/graph_compiler.py b/icefall/graph_compiler.py
index f7ba3cdaf..b4c87d964 100644
--- a/icefall/graph_compiler.py
+++ b/icefall/graph_compiler.py
@@ -1,3 +1,20 @@
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
from typing import List
import k2
@@ -8,7 +25,10 @@ from icefall.lexicon import Lexicon
class CtcTrainingGraphCompiler(object):
def __init__(
- self, lexicon: Lexicon, device: torch.device, oov: str = "",
+ self,
+ lexicon: Lexicon,
+ device: torch.device,
+ oov: str = "",
):
"""
Args:
@@ -86,7 +106,7 @@ class CtcTrainingGraphCompiler(object):
word_ids_list = []
for text in texts:
word_ids = []
- for word in text.split(" "):
+ for word in text.split():
if word in self.word_table:
word_ids.append(self.word_table[word])
else:
diff --git a/icefall/lexicon.py b/icefall/lexicon.py
index 43a0fda37..1378d79fb 100644
--- a/icefall/lexicon.py
+++ b/icefall/lexicon.py
@@ -1,3 +1,20 @@
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
import logging
import re
import sys
@@ -142,7 +159,7 @@ class BpeLexicon(Lexicon):
lang_dir / "lexicon.txt"
)
- def convert_lexicon_to_ragged(self, filename: str) -> k2.RaggedInt:
+ def convert_lexicon_to_ragged(self, filename: str) -> k2.RaggedTensor:
"""Read a BPE lexicon from file and convert it to a
k2 ragged tensor.
@@ -185,19 +202,18 @@ class BpeLexicon(Lexicon):
)
values = torch.tensor(token_ids, dtype=torch.int32)
- return k2.RaggedInt(shape, values)
+ return k2.RaggedTensor(shape, values)
- def words_to_piece_ids(self, words: List[str]) -> k2.RaggedInt:
+ def words_to_piece_ids(self, words: List[str]) -> k2.RaggedTensor:
"""Convert a list of words to a ragged tensor contained
word piece IDs.
"""
word_ids = [self.word_table[w] for w in words]
word_ids = torch.tensor(word_ids, dtype=torch.int32)
- ragged, _ = k2.ragged.index(
- self.ragged_lexicon,
+ ragged, _ = self.ragged_lexicon.index(
indexes=word_ids,
- need_value_indexes=False,
axis=0,
+ need_value_indexes=False,
)
return ragged
diff --git a/icefall/utils.py b/icefall/utils.py
index 3d48badfe..23b4dd6c7 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -1,3 +1,20 @@
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
import argparse
import logging
import os
@@ -9,7 +26,6 @@ from pathlib import Path
from typing import Dict, Iterable, List, TextIO, Tuple, Union
import k2
-import k2.ragged as k2r
import kaldialign
import torch
import torch.distributed as dist
@@ -130,12 +146,20 @@ def get_env_info():
}
-# See
-# https://stackoverflow.com/questions/4984647/accessing-dict-keys-like-an-attribute # noqa
class AttributeDict(dict):
- __slots__ = ()
- __getattr__ = dict.__getitem__
- __setattr__ = dict.__setitem__
+ def __getattr__(self, key):
+ if key in self:
+ return self[key]
+ raise AttributeError(f"No such attribute '{key}'")
+
+ def __setattr__(self, key, value):
+ self[key] = value
+
+ def __delattr__(self, key):
+ if key in self:
+ del self[key]
+ return
+ raise AttributeError(f"No such attribute '{key}'")
def encode_supervisions(
@@ -170,7 +194,9 @@ def encode_supervisions(
return supervision_segments, texts
-def get_texts(best_paths: k2.Fsa) -> List[List[int]]:
+def get_texts(
+ best_paths: k2.Fsa, return_ragged: bool = False
+) -> Union[List[List[int]], k2.RaggedTensor]:
"""Extract the texts (as word IDs) from the best-path FSAs.
Args:
best_paths:
@@ -178,30 +204,35 @@ def get_texts(best_paths: k2.Fsa) -> List[List[int]]:
containing multiple FSAs, which is expected to be the result
of k2.shortest_path (otherwise the returned values won't
be meaningful).
+ return_ragged:
+ True to return a ragged tensor with two axes [utt][word_id].
+ False to return a list-of-list word IDs.
Returns:
Returns a list of lists of int, containing the label sequences we
decoded.
"""
- if isinstance(best_paths.aux_labels, k2.RaggedInt):
+ if isinstance(best_paths.aux_labels, k2.RaggedTensor):
# remove 0's and -1's.
- aux_labels = k2r.remove_values_leq(best_paths.aux_labels, 0)
- aux_shape = k2r.compose_ragged_shapes(
- best_paths.arcs.shape(), aux_labels.shape()
- )
+ aux_labels = best_paths.aux_labels.remove_values_leq(0)
+ # TODO: change arcs.shape() to arcs.shape
+ aux_shape = best_paths.arcs.shape().compose(aux_labels.shape)
# remove the states and arcs axes.
- aux_shape = k2r.remove_axis(aux_shape, 1)
- aux_shape = k2r.remove_axis(aux_shape, 1)
- aux_labels = k2.RaggedInt(aux_shape, aux_labels.values())
+ aux_shape = aux_shape.remove_axis(1)
+ aux_shape = aux_shape.remove_axis(1)
+ aux_labels = k2.RaggedTensor(aux_shape, aux_labels.values)
else:
# remove axis corresponding to states.
- aux_shape = k2r.remove_axis(best_paths.arcs.shape(), 1)
- aux_labels = k2.RaggedInt(aux_shape, best_paths.aux_labels)
+ aux_shape = best_paths.arcs.shape().remove_axis(1)
+ aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels)
# remove 0's and -1's.
- aux_labels = k2r.remove_values_leq(aux_labels, 0)
+ aux_labels = aux_labels.remove_values_leq(0)
- assert aux_labels.num_axes() == 2
- return k2r.to_list(aux_labels)
+ assert aux_labels.num_axes == 2
+ if return_ragged:
+ return aux_labels
+ else:
+ return aux_labels.tolist()
def store_transcripts(
diff --git a/test/test_bpe_graph_compiler.py b/test/test_bpe_graph_compiler.py
index 7b941e5a7..e58c4f1c6 100755
--- a/test/test_bpe_graph_compiler.py
+++ b/test/test_bpe_graph_compiler.py
@@ -1,10 +1,25 @@
#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
-# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
+
+from pathlib import Path
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.lexicon import BpeLexicon
-from pathlib import Path
def test():
@@ -15,7 +30,7 @@ def test():
compiler = BpeCtcTrainingGraphCompiler(lang_dir)
ids = compiler.texts_to_ids(["HELLO", "WORLD ZZZ"])
- fsa = compiler.compile(ids)
+ compiler.compile(ids)
lexicon = BpeLexicon(lang_dir)
ids0 = lexicon.words_to_piece_ids(["HELLO"])
diff --git a/test/test_checkpoint.py b/test/test_checkpoint.py
index 343768957..511a11c23 100644
--- a/test/test_checkpoint.py
+++ b/test/test_checkpoint.py
@@ -1,4 +1,20 @@
#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import pytest
import torch
diff --git a/test/test_decode.py b/test/test_decode.py
new file mode 100644
index 000000000..7ef127781
--- /dev/null
+++ b/test/test_decode.py
@@ -0,0 +1,62 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+You can run this file in one of the two ways:
+
+ (1) cd icefall; pytest test/test_decode.py
+ (2) cd icefall; ./test/test_decode.py
+"""
+
+import k2
+from icefall.decode import Nbest
+
+
+def test_nbest_from_lattice():
+ s = """
+ 0 1 1 10 0.1
+ 0 1 5 10 0.11
+ 0 1 2 20 0.2
+ 1 2 3 30 0.3
+ 1 2 4 40 0.4
+ 2 3 -1 -1 0.5
+ 3
+ """
+ lattice = k2.Fsa.from_str(s, acceptor=False)
+ lattice = k2.Fsa.from_fsas([lattice, lattice])
+
+ nbest = Nbest.from_lattice(
+ lattice=lattice,
+ num_paths=10,
+ use_double_scores=True,
+ lattice_score_scale=0.5,
+ )
+ # each lattice has only 4 distinct paths that have different word sequences:
+ # 10->30
+ # 10->40
+ # 20->30
+ # 20->40
+ #
+ # So there should be only 4 paths for each lattice in the Nbest object
+ assert nbest.fsa.shape[0] == 4 * 2
+ assert nbest.shape.row_splits(1).tolist() == [0, 4, 8]
+
+ nbest2 = nbest.intersect(lattice)
+ tot_scores = nbest2.tot_scores()
+ argmax = tot_scores.argmax()
+ best_path = k2.index_fsa(nbest2.fsa, argmax)
+ print(best_path[0])
diff --git a/test/test_graph_compiler.py b/test/test_graph_compiler.py
index 4083d79ac..ccfb57d49 100644
--- a/test/test_graph_compiler.py
+++ b/test/test_graph_compiler.py
@@ -1,6 +1,20 @@
#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
-# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
import re
diff --git a/test/test_lexicon.py b/test/test_lexicon.py
index b1284d98a..6801b3a89 100644
--- a/test/test_lexicon.py
+++ b/test/test_lexicon.py
@@ -1,4 +1,20 @@
#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from pathlib import Path
diff --git a/test/test_utils.py b/test/test_utils.py
index 27b1ac203..7ac52b289 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -1,4 +1,21 @@
#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
import k2
import pytest
import torch
@@ -43,7 +60,7 @@ def test_get_texts_ragged():
4
"""
)
- fsa1.aux_labels = k2.RaggedInt("[ [1 3 0 2] [] [4 0 1] [-1]]")
+ fsa1.aux_labels = k2.RaggedTensor("[ [1 3 0 2] [] [4 0 1] [-1]]")
fsa2 = k2.Fsa.from_str(
"""
@@ -53,7 +70,7 @@ def test_get_texts_ragged():
3
"""
)
- fsa2.aux_labels = k2.RaggedInt("[[3 0 5 0 8] [0 9 7 0] [-1]]")
+ fsa2.aux_labels = k2.RaggedTensor("[[3 0 5 0 8] [0 9 7 0] [-1]]")
fsas = k2.Fsa.from_fsas([fsa1, fsa2])
texts = get_texts(fsas)
assert texts == [[1, 3, 2, 4, 1], [3, 5, 8, 9, 7]]
@@ -91,3 +108,14 @@ def test_attribute_dict():
assert s["b"] == 20
s.c = 100
assert s["c"] == 100
+ assert hasattr(s, "a")
+ assert hasattr(s, "b")
+ assert getattr(s, "a") == 10
+ del s.a
+ assert hasattr(s, "a") is False
+ setattr(s, "c", 100)
+ s.c = 100
+ try:
+ del s.a
+ except AttributeError as ex:
+ print(f"Caught exception: {ex}")