Merge branch 'master' into embedding-scale

This commit is contained in:
Fangjun Kuang 2021-08-26 14:41:53 +08:00
commit b09224fb3a
82 changed files with 5880 additions and 567 deletions

View File

@ -2,6 +2,9 @@
show-source=true
statistics=true
max-line-length = 80
per-file-ignores =
# line too long
egs/librispeech/ASR/conformer_ctc/conformer.py: E501,
exclude =
.git,

78
.github/workflows/run-yesno-recipe.yml vendored Normal file
View File

@ -0,0 +1,78 @@
# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: run-yesno-recipe
on:
push:
branches:
- master
pull_request:
branches:
- master
jobs:
run-yesno-recipe:
runs-on: ${{ matrix.os }}
strategy:
matrix:
# os: [ubuntu-18.04, macos-10.15]
# TODO: enable macOS for CPU testing
os: [ubuntu-18.04]
python-version: [3.8]
fail-fast: false
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python-version }}
- name: Install libnsdfile and libsox
if: startsWith(matrix.os, 'ubuntu')
run: |
sudo apt update
sudo apt install -q -y libsndfile1-dev libsndfile1 ffmpeg
sudo apt install -q -y --fix-missing sox libsox-dev libsox-fmt-all
- name: Install Python dependencies
run: |
python3 -m pip install --upgrade pip black flake8
python3 -m pip install -U pip
python3 -m pip install k2==1.4.dev20210822+cpu.torch1.7.1 -f https://k2-fsa.org/nightly/
python3 -m pip install torchaudio==0.7.2
python3 -m pip install git+https://github.com/lhotse-speech/lhotse
# We are in ./icefall and there is a file: requirements.txt in it
python3 -m pip install -r requirements.txt
- name: Run yesno recipe
shell: bash
working-directory: ${{github.workspace}}
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
echo $PYTHONPATH
cd egs/yesno/ASR
./prepare.sh
python3 ./tdnn/train.py
python3 ./tdnn/decode.py
# TODO: Check that the WER is less than some value

View File

@ -45,7 +45,7 @@ jobs:
- name: Install Python dependencies
run: |
python3 -m pip install --upgrade pip black flake8
python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2
- name: Run flake8
shell: bash

View File

@ -32,7 +32,7 @@ jobs:
os: [ubuntu-18.04, macos-10.15]
python-version: [3.6, 3.7, 3.8, 3.9]
torch: ["1.8.1"]
k2-version: ["1.2.dev20210724"]
k2-version: ["1.4.dev20210822"]
fail-fast: false
steps:
@ -47,16 +47,10 @@ jobs:
- name: Install Python dependencies
run: |
python3 -m pip install --upgrade pip pytest kaldialign
python3 -m pip install --upgrade pip pytest
pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
# Don't use: pip install lhotse
# since it installs a version of PyTorch that is not predictable
git clone --depth 1 https://github.com/lhotse-speech/lhotse
cd lhotse
sed -i.bak "/torch/d" requirements.txt
pip install -r ./requirements.txt
# icefall requirements
pip install -r requirements.txt
- name: Run tests
if: startsWith(matrix.os, 'ubuntu')

View File

@ -1,67 +1,61 @@
# Table of Contents
- [Installation](#installation)
* [Install k2](#install-k2)
* [Install lhotse](#install-lhotse)
* [Install icefall](#install-icefall)
- [Run recipes](#run-recipes)
<div align="center">
<img src="https://raw.githubusercontent.com/k2-fsa/icefall/master/docs/source/_static/logo.png" width=168>
</div>
## Installation
`icefall` depends on [k2][k2] for FSA operations and [lhotse][lhotse] for
data preparations. To use `icefall`, you have to install its dependencies first.
The following subsections describe how to setup the environment.
Please refer to <https://icefall.readthedocs.io/en/latest/installation/index.html>
for installation.
CAUTION: There are various ways to setup the environment. What we describe
here is just one alternative.
## Recipes
### Install k2
Please refer to <https://icefall.readthedocs.io/en/latest/recipes/index.html>
for more information.
Please refer to [k2's installation documentation][k2-install] to install k2.
If you have any issues about installing k2, please open an issue at
<https://github.com/k2-fsa/k2/issues>.
We provide two recipes at present:
### Install lhotse
- [yesno][yesno]
- [LibriSpeech][librispeech]
Please refer to [lhotse's installation documentation][lhotse-install] to install
lhotse.
### yesno
### Install icefall
This is the simplest ASR recipe in `icefall` and can be run on CPU.
Training takes less than 30 seconds and gives you the following WER:
`icefall` is a set of Python scripts. What you need to do is just to set
the environment variable `PYTHONPATH`:
```bash
cd $HOME/open-source
git clone https://github.com/k2-fsa/icefall
cd icefall
pip install -r requirements.txt
export PYTHONPATH=$HOME/open-source/icefall:$PYTHONPATHON
```
To verify `icefall` was installed successfully, you can run:
```bash
python3 -c "import icefall; print(icefall.__file__)"
[test_set] %WER 0.42% [1 / 240, 0 ins, 1 del, 0 sub ]
```
We do provide a Colab notebook for this recipe.
It should print the path to `icefall`.
## Run recipes
At present, only LibriSpeech recipe is provided. Please
follow [egs/librispeech/ASR/README.md][LibriSpeech] to run it.
## Use Pre-trained models
See [egs/librispeech/ASR/conformer_ctc/README.md](egs/librispeech/ASR/conformer_ctc/README.md)
for how to use pre-trained models.
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1huyupXAcHsUrKaWfI83iMEJ6J0Nh0213?usp=sharing)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1tIjjzaJc3IvGyKiMCDWO-TSnBgkcuN3B?usp=sharing)
[LibriSpeech]: egs/librispeech/ASR/README.md
[k2-install]: https://k2.readthedocs.io/en/latest/installation/index.html#
[k2]: https://github.com/k2-fsa/k2
[lhotse]: https://github.com/lhotse-speech/lhotse
[lhotse-install]: https://lhotse.readthedocs.io/en/latest/getting-started.html#installation
### LibriSpeech
We provide two models for this recipe: [conformer CTC model][LibriSpeech_conformer_ctc]
and [TDNN LSTM CTC model][LibriSpeech_tdnn_lstm_ctc].
#### Conformer CTC Model
The best WER we currently have is:
||test-clean|test-other|
|--|--|--|
|WER| 2.57% | 5.94% |
We provide a Colab notebook to run a pre-trained conformer CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1huyupXAcHsUrKaWfI83iMEJ6J0Nh0213?usp=sharing)
#### TDNN LSTM CTC Model
The WER for this model is:
||test-clean|test-other|
|--|--|--|
|WER| 6.59% | 17.69% |
We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1kNmDXNMwREi0rZGAOIAOJo93REBuOTcd?usp=sharing)
[LibriSpeech_tdnn_lstm_ctc]: egs/librispeech/ASR/tdnn_lstm_ctc
[LibriSpeech_conformer_ctc]: egs/librispeech/ASR/conformer_ctc
[yesno]: egs/yesno/ASR
[librispeech]: egs/librispeech/ASR

1
docs/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
build/

20
docs/Makefile Normal file
View File

@ -0,0 +1,20 @@
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = source
BUILDDIR = build
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

35
docs/make.bat Normal file
View File

@ -0,0 +1,35 @@
@ECHO OFF
pushd %~dp0
REM Command file for Sphinx documentation
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build
if "%1" == "" goto help
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.http://sphinx-doc.org/
exit /b 1
)
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end
:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
:end
popd

2
docs/requirements.txt Normal file
View File

@ -0,0 +1,2 @@
sphinx_rtd_theme
sphinx

Binary file not shown.

After

Width:  |  Height:  |  Size: 666 KiB

77
docs/source/conf.py Normal file
View File

@ -0,0 +1,77 @@
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html
# -- Path setup --------------------------------------------------------------
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
# import os
# import sys
# sys.path.insert(0, os.path.abspath('.'))
import sphinx_rtd_theme
# -- Project information -----------------------------------------------------
project = "icefall"
copyright = "2021, icefall development team"
author = "icefall development team"
# The full version, including alpha/beta/rc tags
release = "0.1"
# -- General configuration ---------------------------------------------------
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
"sphinx_rtd_theme",
]
# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = []
source_suffix = {
".rst": "restructuredtext",
}
master_doc = "index"
# -- Options for HTML output -------------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = "sphinx_rtd_theme"
html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
html_show_sourcelink = True
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ["_static", "installation/images"]
pygments_style = "sphinx"
numfig = True
html_context = {
"display_github": True,
"github_user": "k2-fsa",
"github_repo": "icefall",
"github_version": "master",
"conf_py_path": "/icefall/docs/source/",
}

View File

@ -0,0 +1,67 @@
.. _follow the code style:
Follow the code style
=====================
We use the following tools to make the code style to be as consistent as possible:
- `black <https://github.com/psf/black>`_, to format the code
- `flake8 <https://github.com/PyCQA/flake8>`_, to check the style and quality of the code
- `isort <https://github.com/PyCQA/isort>`_, to sort ``imports``
The following versions of the above tools are used:
- ``black == 12.6b0``
- ``flake8 == 3.9.2``
- ``isort == 5.9.2``
After running the following commands:
.. code-block::
$ git clone https://github.com/k2-fsa/icefall
$ cd icefall
$ pip install pre-commit
$ pre-commit install
it will run the following checks whenever you run ``git commit``, **automatically**:
.. figure:: images/pre-commit-check.png
:width: 600
:align: center
pre-commit hooks invoked by ``git commit`` (Failed).
If any of the above checks failed, your ``git commit`` was not successful.
Please fix any issues reported by the check tools.
.. HINT::
Some of the check tools, i.e., ``black`` and ``isort`` will modify
the files to be commited **in-place**. So please run ``git status``
after failure to see which file has been modified by the tools
before you make any further changes.
After fixing all the failures, run ``git commit`` again and
it should succeed this time:
.. figure:: images/pre-commit-check-success.png
:width: 600
:align: center
pre-commit hooks invoked by ``git commit`` (Succeeded).
If you want to check the style of your code before ``git commit``, you
can do the following:
.. code-block:: bash
$ cd icefall
$ pip install black==21.6b0 flake8==3.9.2 isort==5.9.2
$ black --check your_changed_file.py
$ black your_changed_file.py # modify it in-place
$
$ flake8 your_changed_file.py
$
$ isort --check your_changed_file.py # modify it in-place
$ isort your_changed_file.py

View File

@ -0,0 +1,45 @@
Contributing to Documentation
=============================
We use `sphinx <https://www.sphinx-doc.org/en/master/>`_
for documentation.
Before writing documentation, you have to prepare the environment:
.. code-block:: bash
$ cd docs
$ pip install -r requirements.txt
After setting up the environment, you are ready to write documentation.
Please refer to `reStructuredText Primer <https://www.sphinx-doc.org/en/master/usage/restructuredtext/basics.html>`_
if you are not familiar with ``reStructuredText``.
After writing some documentation, you can build the documentation **locally**
to preview what it looks like if it is published:
.. code-block:: bash
$ cd docs
$ make html
The generated documentation is in ``docs/build/html`` and can be viewed
with the following commands:
.. code-block:: bash
$ cd docs/build/html
$ python3 -m http.server
It will print::
Serving HTTP on 0.0.0.0 port 8000 (http://0.0.0.0:8000/) ...
Open your browser, go to `<http://0.0.0.0:8000/>`_, and you will see
the following:
.. figure:: images/doc-contrib.png
:width: 600
:align: center
View generated documentation locally with ``python3 -m http.server``.

View File

@ -0,0 +1,156 @@
How to create a recipe
======================
.. HINT::
Please read :ref:`follow the code style` to adjust your code sytle.
.. CAUTION::
``icefall`` is designed to be as Pythonic as possible. Please use
Python in your recipe if possible.
Data Preparation
----------------
We recommend you to prepare your training/test/validate dataset
with `lhotse <https://github.com/lhotse-speech/lhotse>`_.
Please refer to `<https://lhotse.readthedocs.io/en/latest/index.html>`_
for how to create a recipe in ``lhotse``.
.. HINT::
The ``yesno`` recipe in ``lhotse`` is a very good example.
Please refer to `<https://github.com/lhotse-speech/lhotse/pull/380>`_,
which shows how to add a new recipe to ``lhotse``.
Suppose you would like to add a recipe for a dataset named ``foo``.
You can do the following:
.. code-block::
$ cd egs
$ mkdir -p foo/ASR
$ cd foo/ASR
$ touch prepare.sh
$ chmod +x prepare.sh
If your dataset is very simple, please follow
`egs/yesno/ASR/prepare.sh <https://github.com/k2-fsa/icefall/blob/master/egs/yesno/ASR/prepare.sh>`_
to write your own ``prepare.sh``.
Otherwise, please refer to
`egs/librispeech/ASR/prepare.sh <https://github.com/k2-fsa/icefall/blob/master/egs/yesno/ASR/prepare.sh>`_
to prepare your data.
Training
--------
Assume you have a fancy model, called ``bar`` for the ``foo`` recipe, you can
organize your files in the following way:
.. code-block::
$ cd egs/foo/ASR
$ mkdir bar
$ cd bar
$ touch README.md model.py train.py decode.py asr_datamodule.py pretrained.py
For instance , the ``yesno`` recipe has a ``tdnn`` model and its directory structure
looks like the following:
.. code-block:: bash
egs/yesno/ASR/tdnn/
|-- README.md
|-- asr_datamodule.py
|-- decode.py
|-- model.py
|-- pretrained.py
`-- train.py
**File description**:
- ``README.md``
It contains information of this recipe, e.g., how to run it, what the WER is, etc.
- ``asr_datamodule.py``
It provides code to create PyTorch dataloaders with train/test/validation dataset.
- ``decode.py``
It takes as inputs the checkpoints saved during the training stage to decode the test
dataset(s).
- ``model.py``
It contains the definition of your fancy neural network model.
- ``pretrained.py``
We can use this script to do inference with a pre-trained model.
- ``train.py``
It contains training code.
.. HINT::
Please take a look at
- `egs/yesno/tdnn <https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR/tdnn>`_
- `egs/librispeech/tdnn_lstm_ctc <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/tdnn_lstm_ctc>`_
- `egs/librispeech/conformer_ctc <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/conformer_ctc>`_
to get a feel what the resulting files look like.
.. NOTE::
Every model in a recipe is kept to be as self-contained as possible.
We tolerate duplicate code among different recipes.
The training stage should be invocable by:
.. code-block::
$ cd egs/foo/ASR
$ ./bar/train.py
$ ./bar/train.py --help
Decoding
--------
Please refer to
- `<https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/conformer_ctc/decode.py>`_
If your model is transformer/conformer based.
- `<https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py>`_
If your model is TDNN/LSTM based, i.e., there is no attention decoder.
- `<https://github.com/k2-fsa/icefall/blob/master/egs/yesno/ASR/tdnn/decode.py>`_
If there is no LM rescoring.
The decoding stage should be invocable by:
.. code-block::
$ cd egs/foo/ASR
$ ./bar/decode.py
$ ./bar/decode.py --help
Pre-trained model
-----------------
Please demonstrate how to use your model for inference in ``egs/foo/ASR/bar/pretrained.py``.
If possible, please consider creating a Colab notebook to show that.

Binary file not shown.

After

Width:  |  Height:  |  Size: 198 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 153 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 214 KiB

View File

@ -0,0 +1,22 @@
Contributing
============
Contributions to ``icefall`` are very welcomed.
There are many possible ways to make contributions and
two of them are:
- To write documentation
- To write code
- (1) To follow the code style in the repository
- (2) To write a new recipe
In this page, we describe how to contribute documentation
and code to ``icefall``.
.. toctree::
:maxdepth: 2
doc
code-style
how-to-create-a-recipe

25
docs/source/index.rst Normal file
View File

@ -0,0 +1,25 @@
.. icefall documentation master file, created by
sphinx-quickstart on Mon Aug 23 16:07:39 2021.
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
Icefall
=======
.. image:: _static/logo.png
:alt: icefall logo
:width: 168px
:align: center
:target: https://github.com/k2-fsa/icefall
Documentation for `icefall <https://github.com/k2-fsa/icefall>`_, containing
speech recognition recipes using `k2 <https://github.com/k2-fsa/k2>`_.
.. toctree::
:maxdepth: 2
:caption: Contents:
installation/index
recipes/index
contributing/index

View File

@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" width="122" height="20" role="img" aria-label="device: CPU | CUDA"><title>device: CPU | CUDA</title><linearGradient id="s" x2="0" y2="100%"><stop offset="0" stop-color="#bbb" stop-opacity=".1"/><stop offset="1" stop-opacity=".1"/></linearGradient><clipPath id="r"><rect width="122" height="20" rx="3" fill="#fff"/></clipPath><g clip-path="url(#r)"><rect width="45" height="20" fill="#555"/><rect x="45" width="77" height="20" fill="#fe7d37"/><rect width="122" height="20" fill="url(#s)"/></g><g fill="#fff" text-anchor="middle" font-family="Verdana,Geneva,DejaVu Sans,sans-serif" text-rendering="geometricPrecision" font-size="110"><text aria-hidden="true" x="235" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="350">device</text><text x="235" y="140" transform="scale(.1)" fill="#fff" textLength="350">device</text><text aria-hidden="true" x="825" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="670">CPU | CUDA</text><text x="825" y="140" transform="scale(.1)" fill="#fff" textLength="670">CPU | CUDA</text></g></svg>

After

Width:  |  Height:  |  Size: 1.1 KiB

View File

@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" width="114" height="20" role="img" aria-label="os: Linux | macOS"><title>os: Linux | macOS</title><linearGradient id="s" x2="0" y2="100%"><stop offset="0" stop-color="#bbb" stop-opacity=".1"/><stop offset="1" stop-opacity=".1"/></linearGradient><clipPath id="r"><rect width="114" height="20" rx="3" fill="#fff"/></clipPath><g clip-path="url(#r)"><rect width="23" height="20" fill="#555"/><rect x="23" width="91" height="20" fill="#ff69b4"/><rect width="114" height="20" fill="url(#s)"/></g><g fill="#fff" text-anchor="middle" font-family="Verdana,Geneva,DejaVu Sans,sans-serif" text-rendering="geometricPrecision" font-size="110"><text aria-hidden="true" x="125" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="130">os</text><text x="125" y="140" transform="scale(.1)" fill="#fff" textLength="130">os</text><text aria-hidden="true" x="675" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="810">Linux | macOS</text><text x="675" y="140" transform="scale(.1)" fill="#fff" textLength="810">Linux | macOS</text></g></svg>

After

Width:  |  Height:  |  Size: 1.1 KiB

View File

@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" width="170" height="20" role="img" aria-label="python: 3.6 | 3.7 | 3.8 | 3.9"><title>python: 3.6 | 3.7 | 3.8 | 3.9</title><linearGradient id="s" x2="0" y2="100%"><stop offset="0" stop-color="#bbb" stop-opacity=".1"/><stop offset="1" stop-opacity=".1"/></linearGradient><clipPath id="r"><rect width="170" height="20" rx="3" fill="#fff"/></clipPath><g clip-path="url(#r)"><rect width="49" height="20" fill="#555"/><rect x="49" width="121" height="20" fill="#007ec6"/><rect width="170" height="20" fill="url(#s)"/></g><g fill="#fff" text-anchor="middle" font-family="Verdana,Geneva,DejaVu Sans,sans-serif" text-rendering="geometricPrecision" font-size="110"><text aria-hidden="true" x="255" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="390">python</text><text x="255" y="140" transform="scale(.1)" fill="#fff" textLength="390">python</text><text aria-hidden="true" x="1085" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="1110">3.6 | 3.7 | 3.8 | 3.9</text><text x="1085" y="140" transform="scale(.1)" fill="#fff" textLength="1110">3.6 | 3.7 | 3.8 | 3.9</text></g></svg>

After

Width:  |  Height:  |  Size: 1.2 KiB

View File

@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" width="286" height="20" role="img" aria-label="torch: 1.6.0 | 1.7.0 | 1.7.1 | 1.8.0 | 1.8.1 | 1.9.0"><title>torch: 1.6.0 | 1.7.0 | 1.7.1 | 1.8.0 | 1.8.1 | 1.9.0</title><linearGradient id="s" x2="0" y2="100%"><stop offset="0" stop-color="#bbb" stop-opacity=".1"/><stop offset="1" stop-opacity=".1"/></linearGradient><clipPath id="r"><rect width="286" height="20" rx="3" fill="#fff"/></clipPath><g clip-path="url(#r)"><rect width="39" height="20" fill="#555"/><rect x="39" width="247" height="20" fill="#97ca00"/><rect width="286" height="20" fill="url(#s)"/></g><g fill="#fff" text-anchor="middle" font-family="Verdana,Geneva,DejaVu Sans,sans-serif" text-rendering="geometricPrecision" font-size="110"><text aria-hidden="true" x="205" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="290">torch</text><text x="205" y="140" transform="scale(.1)" fill="#fff" textLength="290">torch</text><text aria-hidden="true" x="1615" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="2370">1.6.0 | 1.7.0 | 1.7.1 | 1.8.0 | 1.8.1 | 1.9.0</text><text x="1615" y="140" transform="scale(.1)" fill="#fff" textLength="2370">1.6.0 | 1.7.0 | 1.7.1 | 1.8.0 | 1.8.1 | 1.9.0</text></g></svg>

After

Width:  |  Height:  |  Size: 1.3 KiB

View File

@ -0,0 +1,453 @@
.. _install icefall:
Installation
============
- |os|
- |device|
- |python_versions|
- |torch_versions|
.. |os| image:: ./images/os-Linux_macOS-ff69b4.svg
:alt: Supported operating systems
.. |device| image:: ./images/device-CPU_CUDA-orange.svg
:alt: Supported devices
.. |python_versions| image:: ./images/python-3.6_3.7_3.8_3.9-blue.svg
:alt: Supported python versions
.. |torch_versions| image:: ./images/torch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg
:alt: Supported PyTorch versions
icefall depends on `k2 <https://github.com/k2-fsa/k2>`_ and
`lhotse <https://github.com/lhotse-speech/lhotse>`_.
We recommend you to install ``k2`` first, as ``k2`` is bound to
a specific version of PyTorch after compilation. Install ``k2`` also
installs its dependency PyTorch, which can be reused by ``lhotse``.
(1) Install k2
--------------
Please refer to `<https://k2.readthedocs.io/en/latest/installation/index.html>`_
to install `k2`.
.. HINT::
If you have already installed PyTorch and don't want to replace it,
please install a version of k2 that is compiled against the version
of PyTorch you are using.
(2) Install lhotse
------------------
Please refer to `<https://lhotse.readthedocs.io/en/latest/getting-started.html#installation>`_
to install ``lhotse``.
.. HINT::
Install ``lhotse`` also installs its dependency `torchaudio <https://github.com/pytorch/audio>`_.
(3) Download icefall
--------------------
icefall is a collection of Python scripts, so you don't need to install it
and we don't provide a ``setup.py`` to install it.
What you need is to download it and set the environment variable ``PYTHONPATH``
to point to it.
Assume you want to place ``icefall`` in the folder ``/tmp``. The
following commands show you how to setup ``icefall``:
.. code-block:: bash
cd /tmp
git clone https://github.com/k2-fsa/icefall
cd icefall
pip install -r requirements.txt
export PYTHONPATH=/tmp/icefall:$PYTHONPATH
.. HINT::
You can put several versions of ``icefall`` in the same virtual environment.
To switch among different versions of ``icefall``, just set ``PYTHONPATH``
to point to the version you want.
Installation example
--------------------
The following shows an example about setting up the environment.
(1) Create a virtual environment
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code-block:: bash
$ virtualenv -p python3.8 test-icefall
created virtual environment CPython3.8.6.final.0-64 in 1540ms
creator CPython3Posix(dest=/ceph-fj/fangjun/test-icefall, clear=False, no_vcs_ignore=False, global=False)
seeder FromAppData(download=False, pip=bundle, setuptools=bundle, wheel=bundle, via=copy, app_data_dir=/root/fangjun/.local/share/v
irtualenv)
added seed packages: pip==21.1.3, setuptools==57.4.0, wheel==0.36.2
activators BashActivator,CShellActivator,FishActivator,PowerShellActivator,PythonActivator,XonshActivator
(2) Activate your virtual environment
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code-block:: bash
$ source test-icefall/bin/activate
(3) Install k2
~~~~~~~~~~~~~~
.. code-block:: bash
$ pip install k2==1.4.dev20210822+cpu.torch1.9.0 -f https://k2-fsa.org/nightly/index.html
Looking in links: https://k2-fsa.org/nightly/index.html
Collecting k2==1.4.dev20210822+cpu.torch1.9.0
Downloading https://k2-fsa.org/nightly/whl/k2-1.4.dev20210822%2Bcpu.torch1.9.0-cp38-cp38-linux_x86_64.whl (1.6 MB)
|________________________________| 1.6 MB 185 kB/s
Collecting graphviz
Downloading graphviz-0.17-py3-none-any.whl (18 kB)
Collecting torch==1.9.0
Using cached torch-1.9.0-cp38-cp38-manylinux1_x86_64.whl (831.4 MB)
Collecting typing-extensions
Using cached typing_extensions-3.10.0.0-py3-none-any.whl (26 kB)
Installing collected packages: typing-extensions, torch, graphviz, k2
Successfully installed graphviz-0.17 k2-1.4.dev20210822+cpu.torch1.9.0 torch-1.9.0 typing-extensions-3.10.0.0
.. WARNING::
We choose to install a CPU version of k2 for testing. You would probably want to install
a CUDA version of k2.
(4) Install lhotse
~~~~~~~~~~~~~~~~~~
.. code-block::
$ pip install git+https://github.com/lhotse-speech/lhotse
Collecting git+https://github.com/lhotse-speech/lhotse
Cloning https://github.com/lhotse-speech/lhotse to /tmp/pip-req-build-7b1b76ge
Running command git clone -q https://github.com/lhotse-speech/lhotse /tmp/pip-req-build-7b1b76ge
Collecting audioread>=2.1.9
Using cached audioread-2.1.9-py3-none-any.whl
Collecting SoundFile>=0.10
Using cached SoundFile-0.10.3.post1-py2.py3-none-any.whl (21 kB)
Collecting click>=7.1.1
Using cached click-8.0.1-py3-none-any.whl (97 kB)
Collecting cytoolz>=0.10.1
Using cached cytoolz-0.11.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.9 MB)
Collecting dataclasses
Using cached dataclasses-0.6-py3-none-any.whl (14 kB)
Collecting h5py>=2.10.0
Downloading h5py-3.4.0-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (4.5 MB)
|________________________________| 4.5 MB 684 kB/s
Collecting intervaltree>=3.1.0
Using cached intervaltree-3.1.0-py2.py3-none-any.whl
Collecting lilcom>=1.1.0
Using cached lilcom-1.1.1-cp38-cp38-linux_x86_64.whl
Collecting numpy>=1.18.1
Using cached numpy-1.21.2-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (15.8 MB)
Collecting packaging
Using cached packaging-21.0-py3-none-any.whl (40 kB)
Collecting pyyaml>=5.3.1
Using cached PyYAML-5.4.1-cp38-cp38-manylinux1_x86_64.whl (662 kB)
Collecting tqdm
Downloading tqdm-4.62.1-py2.py3-none-any.whl (76 kB)
|________________________________| 76 kB 2.7 MB/s
Collecting torchaudio==0.9.0
Downloading torchaudio-0.9.0-cp38-cp38-manylinux1_x86_64.whl (1.9 MB)
|________________________________| 1.9 MB 73.1 MB/s
Requirement already satisfied: torch==1.9.0 in ./test-icefall/lib/python3.8/site-packages (from torchaudio==0.9.0->lhotse===0.8.0.dev
-2a1410b-clean) (1.9.0)
Requirement already satisfied: typing-extensions in ./test-icefall/lib/python3.8/site-packages (from torch==1.9.0->torchaudio==0.9.0-
>lhotse===0.8.0.dev-2a1410b-clean) (3.10.0.0)
Collecting toolz>=0.8.0
Using cached toolz-0.11.1-py3-none-any.whl (55 kB)
Collecting sortedcontainers<3.0,>=2.0
Using cached sortedcontainers-2.4.0-py2.py3-none-any.whl (29 kB)
Collecting cffi>=1.0
Using cached cffi-1.14.6-cp38-cp38-manylinux1_x86_64.whl (411 kB)
Collecting pycparser
Using cached pycparser-2.20-py2.py3-none-any.whl (112 kB)
Collecting pyparsing>=2.0.2
Using cached pyparsing-2.4.7-py2.py3-none-any.whl (67 kB)
Building wheels for collected packages: lhotse
Building wheel for lhotse (setup.py) ... done
Created wheel for lhotse: filename=lhotse-0.8.0.dev_2a1410b_clean-py3-none-any.whl size=342242 sha256=f683444afa4dc0881133206b4646a
9d0f774224cc84000f55d0a67f6e4a37997
Stored in directory: /tmp/pip-ephem-wheel-cache-ftu0qysz/wheels/7f/7a/8e/a0bf241336e2e3cb573e1e21e5600952d49f5162454f2e612f
WARNING: Built wheel for lhotse is invalid: Metadata 1.2 mandates PEP 440 version, but '0.8.0.dev-2a1410b-clean' is not
Failed to build lhotse
Installing collected packages: pycparser, toolz, sortedcontainers, pyparsing, numpy, cffi, tqdm, torchaudio, SoundFile, pyyaml, packa
ging, lilcom, intervaltree, h5py, dataclasses, cytoolz, click, audioread, lhotse
Running setup.py install for lhotse ... done
DEPRECATION: lhotse was installed using the legacy 'setup.py install' method, because a wheel could not be built for it. A possible
replacement is to fix the wheel build issue reported above. You can find discussion regarding this at https://github.com/pypa/pip/is
sues/8368.
Successfully installed SoundFile-0.10.3.post1 audioread-2.1.9 cffi-1.14.6 click-8.0.1 cytoolz-0.11.0 dataclasses-0.6 h5py-3.4.0 inter
valtree-3.1.0 lhotse-0.8.0.dev-2a1410b-clean lilcom-1.1.1 numpy-1.21.2 packaging-21.0 pycparser-2.20 pyparsing-2.4.7 pyyaml-5.4.1 sor
tedcontainers-2.4.0 toolz-0.11.1 torchaudio-0.9.0 tqdm-4.62.1
(5) Download icefall
~~~~~~~~~~~~~~~~~~~~
.. code-block::
$ cd /tmp
$ git clone https://github.com/k2-fsa/icefall
Cloning into 'icefall'...
remote: Enumerating objects: 500, done.
remote: Counting objects: 100% (500/500), done.
remote: Compressing objects: 100% (308/308), done.
remote: Total 500 (delta 263), reused 307 (delta 102), pack-reused 0
Receiving objects: 100% (500/500), 172.49 KiB | 385.00 KiB/s, done.
Resolving deltas: 100% (263/263), done.
$ cd icefall
$ pip install -r requirements.txt
Collecting kaldilm
Downloading kaldilm-1.8.tar.gz (48 kB)
|________________________________| 48 kB 574 kB/s
Collecting kaldialign
Using cached kaldialign-0.2-cp38-cp38-linux_x86_64.whl
Collecting sentencepiece>=0.1.96
Using cached sentencepiece-0.1.96-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
Collecting tensorboard
Using cached tensorboard-2.6.0-py3-none-any.whl (5.6 MB)
Requirement already satisfied: setuptools>=41.0.0 in /ceph-fj/fangjun/test-icefall/lib/python3.8/site-packages (from tensorboard->-r
requirements.txt (line 4)) (57.4.0)
Collecting absl-py>=0.4
Using cached absl_py-0.13.0-py3-none-any.whl (132 kB)
Collecting google-auth-oauthlib<0.5,>=0.4.1
Using cached google_auth_oauthlib-0.4.5-py2.py3-none-any.whl (18 kB)
Collecting grpcio>=1.24.3
Using cached grpcio-1.39.0-cp38-cp38-manylinux2014_x86_64.whl (4.3 MB)
Requirement already satisfied: wheel>=0.26 in /ceph-fj/fangjun/test-icefall/lib/python3.8/site-packages (from tensorboard->-r require
ments.txt (line 4)) (0.36.2)
Requirement already satisfied: numpy>=1.12.0 in /ceph-fj/fangjun/test-icefall/lib/python3.8/site-packages (from tensorboard->-r requi
rements.txt (line 4)) (1.21.2)
Collecting protobuf>=3.6.0
Using cached protobuf-3.17.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.0 MB)
Collecting werkzeug>=0.11.15
Using cached Werkzeug-2.0.1-py3-none-any.whl (288 kB)
Collecting tensorboard-data-server<0.7.0,>=0.6.0
Using cached tensorboard_data_server-0.6.1-py3-none-manylinux2010_x86_64.whl (4.9 MB)
Collecting google-auth<2,>=1.6.3
Downloading google_auth-1.35.0-py2.py3-none-any.whl (152 kB)
|________________________________| 152 kB 1.4 MB/s
Collecting requests<3,>=2.21.0
Using cached requests-2.26.0-py2.py3-none-any.whl (62 kB)
Collecting tensorboard-plugin-wit>=1.6.0
Using cached tensorboard_plugin_wit-1.8.0-py3-none-any.whl (781 kB)
Collecting markdown>=2.6.8
Using cached Markdown-3.3.4-py3-none-any.whl (97 kB)
Collecting six
Using cached six-1.16.0-py2.py3-none-any.whl (11 kB)
Collecting cachetools<5.0,>=2.0.0
Using cached cachetools-4.2.2-py3-none-any.whl (11 kB)
Collecting rsa<5,>=3.1.4
Using cached rsa-4.7.2-py3-none-any.whl (34 kB)
Collecting pyasn1-modules>=0.2.1
Using cached pyasn1_modules-0.2.8-py2.py3-none-any.whl (155 kB)
Collecting requests-oauthlib>=0.7.0
Using cached requests_oauthlib-1.3.0-py2.py3-none-any.whl (23 kB)
Collecting pyasn1<0.5.0,>=0.4.6
Using cached pyasn1-0.4.8-py2.py3-none-any.whl (77 kB)
Collecting urllib3<1.27,>=1.21.1
Using cached urllib3-1.26.6-py2.py3-none-any.whl (138 kB)
Collecting certifi>=2017.4.17
Using cached certifi-2021.5.30-py2.py3-none-any.whl (145 kB)
Collecting charset-normalizer~=2.0.0
Using cached charset_normalizer-2.0.4-py3-none-any.whl (36 kB)
Collecting idna<4,>=2.5
Using cached idna-3.2-py3-none-any.whl (59 kB)
Collecting oauthlib>=3.0.0
Using cached oauthlib-3.1.1-py2.py3-none-any.whl (146 kB)
Building wheels for collected packages: kaldilm
Building wheel for kaldilm (setup.py) ... done
Created wheel for kaldilm: filename=kaldilm-1.8-cp38-cp38-linux_x86_64.whl size=897233 sha256=eccb906cafcd45bf9a7e1a1718e4534254bfb
f4c0d0cbc66eee6c88d68a63862
Stored in directory: /root/fangjun/.cache/pip/wheels/85/7d/63/f2dd586369b8797cb36d213bf3a84a789eeb92db93d2e723c9
Successfully built kaldilm
Installing collected packages: urllib3, pyasn1, idna, charset-normalizer, certifi, six, rsa, requests, pyasn1-modules, oauthlib, cach
etools, requests-oauthlib, google-auth, werkzeug, tensorboard-plugin-wit, tensorboard-data-server, protobuf, markdown, grpcio, google
-auth-oauthlib, absl-py, tensorboard, sentencepiece, kaldilm, kaldialign
Successfully installed absl-py-0.13.0 cachetools-4.2.2 certifi-2021.5.30 charset-normalizer-2.0.4 google-auth-1.35.0 google-auth-oaut
hlib-0.4.5 grpcio-1.39.0 idna-3.2 kaldialign-0.2 kaldilm-1.8 markdown-3.3.4 oauthlib-3.1.1 protobuf-3.17.3 pyasn1-0.4.8 pyasn1-module
s-0.2.8 requests-2.26.0 requests-oauthlib-1.3.0 rsa-4.7.2 sentencepiece-0.1.96 six-1.16.0 tensorboard-2.6.0 tensorboard-data-server-0
.6.1 tensorboard-plugin-wit-1.8.0 urllib3-1.26.6 werkzeug-2.0.1
Test Your Installation
----------------------
To test that your installation is successful, let us run
the `yesno recipe <https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR>`_
on CPU.
Data preparation
~~~~~~~~~~~~~~~~
.. code-block:: bash
$ export PYTHONPATH=/tmp/icefall:$PYTHONPATH
$ cd /tmp/icefall
$ cd egs/yesno/ASR
$ ./prepare.sh
The log of running ``./prepare.sh`` is:
.. code-block::
2021-08-23 19:27:26 (prepare.sh:24:main) dl_dir: /tmp/icefall/egs/yesno/ASR/download
2021-08-23 19:27:26 (prepare.sh:27:main) stage 0: Download data
Downloading waves_yesno.tar.gz: 4.49MB [00:03, 1.39MB/s]
2021-08-23 19:27:30 (prepare.sh:36:main) Stage 1: Prepare yesno manifest
2021-08-23 19:27:31 (prepare.sh:42:main) Stage 2: Compute fbank for yesno
2021-08-23 19:27:32,803 INFO [compute_fbank_yesno.py:52] Processing train
Extracting and storing features: 100%|_______________________________________________________________| 90/90 [00:01<00:00, 80.57it/s]
2021-08-23 19:27:34,085 INFO [compute_fbank_yesno.py:52] Processing test
Extracting and storing features: 100%|______________________________________________________________| 30/30 [00:00<00:00, 248.21it/s]
2021-08-23 19:27:34 (prepare.sh:48:main) Stage 3: Prepare lang
2021-08-23 19:27:35 (prepare.sh:63:main) Stage 4: Prepare G
/tmp/pip-install-fcordre9/kaldilm_6899d26f2d684ad48f21025950cd2866/kaldilm/csrc/arpa_file_parser.cc:void kaldilm::ArpaFileParser::Rea
d(std::istream&):79
[I] Reading \data\ section.
/tmp/pip-install-fcordre9/kaldilm_6899d26f2d684ad48f21025950cd2866/kaldilm/csrc/arpa_file_parser.cc:void kaldilm::ArpaFileParser::Rea
d(std::istream&):140
[I] Reading \1-grams: section.
2021-08-23 19:27:35 (prepare.sh:89:main) Stage 5: Compile HLG
2021-08-23 19:27:35,928 INFO [compile_hlg.py:120] Processing data/lang_phone
2021-08-23 19:27:35,929 INFO [lexicon.py:116] Converting L.pt to Linv.pt
2021-08-23 19:27:35,931 INFO [compile_hlg.py:48] Building ctc_topo. max_token_id: 3
2021-08-23 19:27:35,932 INFO [compile_hlg.py:52] Loading G.fst.txt
2021-08-23 19:27:35,932 INFO [compile_hlg.py:62] Intersecting L and G
2021-08-23 19:27:35,933 INFO [compile_hlg.py:64] LG shape: (4, None)
2021-08-23 19:27:35,933 INFO [compile_hlg.py:66] Connecting LG
2021-08-23 19:27:35,933 INFO [compile_hlg.py:68] LG shape after k2.connect: (4, None)
2021-08-23 19:27:35,933 INFO [compile_hlg.py:70] <class 'torch.Tensor'>
2021-08-23 19:27:35,933 INFO [compile_hlg.py:71] Determinizing LG
2021-08-23 19:27:35,934 INFO [compile_hlg.py:74] <class '_k2.RaggedInt'>
2021-08-23 19:27:35,934 INFO [compile_hlg.py:76] Connecting LG after k2.determinize
2021-08-23 19:27:35,934 INFO [compile_hlg.py:79] Removing disambiguation symbols on LG
2021-08-23 19:27:35,934 INFO [compile_hlg.py:87] LG shape after k2.remove_epsilon: (6, None)
2021-08-23 19:27:35,935 INFO [compile_hlg.py:92] Arc sorting LG
2021-08-23 19:27:35,935 INFO [compile_hlg.py:95] Composing H and LG
2021-08-23 19:27:35,935 INFO [compile_hlg.py:102] Connecting LG
2021-08-23 19:27:35,935 INFO [compile_hlg.py:105] Arc sorting LG
2021-08-23 19:27:35,936 INFO [compile_hlg.py:107] HLG.shape: (8, None)
2021-08-23 19:27:35,936 INFO [compile_hlg.py:123] Saving HLG.pt to data/lang_phone
Training
~~~~~~~~
Now let us run the training part:
.. code-block::
$ export CUDA_VISIBLE_DEVICES=""
$ ./tdnn/train.py
.. CAUTION::
We use ``export CUDA_VISIBLE_DEVICES=""`` so that icefall uses CPU
even if there are GPUs available.
The training log is given below:
.. code-block::
2021-08-23 19:30:31,072 INFO [train.py:465] Training started
2021-08-23 19:30:31,072 INFO [train.py:466] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lr': 0.01,
'feature_dim': 23, 'weight_decay': 1e-06, 'start_epoch': 0, 'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, '
best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 10, 'valid_interval': 10, 'beam_size': 10, 'reduction': 'sum', 'use_doub
le_scores': True, 'world_size': 1, 'master_port': 12354, 'tensorboard': True, 'num_epochs': 15, 'feature_dir': PosixPath('data/fbank'
), 'max_duration': 30.0, 'bucketing_sampler': False, 'num_buckets': 10, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0
, 'on_the_fly_feats': False, 'shuffle': True, 'return_cuts': True, 'num_workers': 2}
2021-08-23 19:30:31,074 INFO [lexicon.py:113] Loading pre-compiled data/lang_phone/Linv.pt
2021-08-23 19:30:31,098 INFO [asr_datamodule.py:146] About to get train cuts
2021-08-23 19:30:31,098 INFO [asr_datamodule.py:240] About to get train cuts
2021-08-23 19:30:31,102 INFO [asr_datamodule.py:149] About to create train dataset
2021-08-23 19:30:31,102 INFO [asr_datamodule.py:200] Using SingleCutSampler.
2021-08-23 19:30:31,102 INFO [asr_datamodule.py:206] About to create train dataloader
2021-08-23 19:30:31,102 INFO [asr_datamodule.py:219] About to get test cuts
2021-08-23 19:30:31,102 INFO [asr_datamodule.py:246] About to get test cuts
2021-08-23 19:30:31,357 INFO [train.py:416] Epoch 0, batch 0, batch avg loss 1.0789, total avg loss: 1.0789, batch size: 4
2021-08-23 19:30:31,848 INFO [train.py:416] Epoch 0, batch 10, batch avg loss 0.5356, total avg loss: 0.7556, batch size: 4
2021-08-23 19:30:32,301 INFO [train.py:432] Epoch 0, valid loss 0.9972, best valid loss: 0.9972 best valid epoch: 0
2021-08-23 19:30:32,805 INFO [train.py:416] Epoch 0, batch 20, batch avg loss 0.2436, total avg loss: 0.5717, batch size: 3
2021-08-23 19:30:33,109 INFO [train.py:432] Epoch 0, valid loss 0.4167, best valid loss: 0.4167 best valid epoch: 0
2021-08-23 19:30:33,121 INFO [checkpoint.py:62] Saving checkpoint to tdnn/exp/epoch-0.pt
2021-08-23 19:30:33,325 INFO [train.py:416] Epoch 1, batch 0, batch avg loss 0.2214, total avg loss: 0.2214, batch size: 5
2021-08-23 19:30:33,798 INFO [train.py:416] Epoch 1, batch 10, batch avg loss 0.0781, total avg loss: 0.1343, batch size: 5
2021-08-23 19:30:34,065 INFO [train.py:432] Epoch 1, valid loss 0.0859, best valid loss: 0.0859 best valid epoch: 1
2021-08-23 19:30:34,556 INFO [train.py:416] Epoch 1, batch 20, batch avg loss 0.0421, total avg loss: 0.0975, batch size: 3
2021-08-23 19:30:34,810 INFO [train.py:432] Epoch 1, valid loss 0.0431, best valid loss: 0.0431 best valid epoch: 1
2021-08-23 19:30:34,824 INFO [checkpoint.py:62] Saving checkpoint to tdnn/exp/epoch-1.pt
... ...
2021-08-23 19:30:49,657 INFO [train.py:416] Epoch 13, batch 0, batch avg loss 0.0109, total avg loss: 0.0109, batch size: 5
2021-08-23 19:30:49,984 INFO [train.py:416] Epoch 13, batch 10, batch avg loss 0.0093, total avg loss: 0.0096, batch size: 4
2021-08-23 19:30:50,239 INFO [train.py:432] Epoch 13, valid loss 0.0104, best valid loss: 0.0101 best valid epoch: 12
2021-08-23 19:30:50,569 INFO [train.py:416] Epoch 13, batch 20, batch avg loss 0.0092, total avg loss: 0.0096, batch size: 2
2021-08-23 19:30:50,819 INFO [train.py:432] Epoch 13, valid loss 0.0101, best valid loss: 0.0101 best valid epoch: 13
2021-08-23 19:30:50,835 INFO [checkpoint.py:62] Saving checkpoint to tdnn/exp/epoch-13.pt
2021-08-23 19:30:51,024 INFO [train.py:416] Epoch 14, batch 0, batch avg loss 0.0105, total avg loss: 0.0105, batch size: 5
2021-08-23 19:30:51,317 INFO [train.py:416] Epoch 14, batch 10, batch avg loss 0.0099, total avg loss: 0.0097, batch size: 4
2021-08-23 19:30:51,552 INFO [train.py:432] Epoch 14, valid loss 0.0108, best valid loss: 0.0101 best valid epoch: 13
2021-08-23 19:30:51,869 INFO [train.py:416] Epoch 14, batch 20, batch avg loss 0.0096, total avg loss: 0.0097, batch size: 5
2021-08-23 19:30:52,107 INFO [train.py:432] Epoch 14, valid loss 0.0102, best valid loss: 0.0101 best valid epoch: 13
2021-08-23 19:30:52,126 INFO [checkpoint.py:62] Saving checkpoint to tdnn/exp/epoch-14.pt
2021-08-23 19:30:52,128 INFO [train.py:537] Done!
Decoding
~~~~~~~~
Let us use the trained model to decode the test set:
.. code-block::
$ ./tdnn/decode.py
The decoding log is:
.. code-block::
2021-08-23 19:35:30,192 INFO [decode.py:249] Decoding started
2021-08-23 19:35:30,192 INFO [decode.py:250] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lm_dir': PosixPath('data/lm'), 'feature_dim': 23, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'epoch': 14, 'avg': 2, 'feature_dir': PosixPath('data/fbank'), 'max_duration': 30.0, 'bucketing_sampler': False, 'num_buckets': 10, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': True, 'return_cuts': True, 'num_workers': 2}
2021-08-23 19:35:30,193 INFO [lexicon.py:113] Loading pre-compiled data/lang_phone/Linv.pt
2021-08-23 19:35:30,213 INFO [decode.py:259] device: cpu
2021-08-23 19:35:30,217 INFO [decode.py:279] averaging ['tdnn/exp/epoch-13.pt', 'tdnn/exp/epoch-14.pt']
/tmp/icefall/icefall/checkpoint.py:146: UserWarning: floor_divide is deprecated, and will be removed in a future version of pytorch.
It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values.
To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at /pytorch/aten/src/ATen/native/BinaryOps.cpp:450.)
avg[k] //= n
2021-08-23 19:35:30,220 INFO [asr_datamodule.py:219] About to get test cuts
2021-08-23 19:35:30,220 INFO [asr_datamodule.py:246] About to get test cuts
2021-08-23 19:35:30,409 INFO [decode.py:190] batch 0/8, cuts processed until now is 4
2021-08-23 19:35:30,571 INFO [decode.py:228] The transcripts are stored in tdnn/exp/recogs-test_set.txt
2021-08-23 19:35:30,572 INFO [utils.py:317] [test_set] %WER 0.42% [1 / 240, 0 ins, 1 del, 0 sub ]
2021-08-23 19:35:30,573 INFO [decode.py:236] Wrote detailed error stats to tdnn/exp/errs-test_set.txt
2021-08-23 19:35:30,573 INFO [decode.py:299] Done!
**Congratulations!** You have successfully setup the environment and have run the first recipe in ``icefall``.
Have fun with ``icefall``!

Binary file not shown.

After

Width:  |  Height:  |  Size: 121 KiB

View File

@ -0,0 +1,18 @@
Recipes
=======
This page contains various recipes in ``icefall``.
Currently, only speech recognition recipes are provided.
We may add recipes for other tasks as well in the future.
.. we put the yesno recipe as the first recipe since it is the simplest one.
.. Other recipes are listed in a alphabetical order.
.. toctree::
:maxdepth: 2
yesno
librispeech

View File

@ -0,0 +1,10 @@
LibriSpeech
===========
We provide the following models for the LibriSpeech dataset:
.. toctree::
:maxdepth: 2
librispeech/tdnn_lstm_ctc
librispeech/conformer_ctc

View File

@ -0,0 +1,627 @@
Confromer CTC
=============
This tutorial shows you how to run a conformer ctc model
with the `LibriSpeech <https://www.openslr.org/12>`_ dataset.
.. HINT::
We assume you have read the page :ref:`install icefall` and have setup
the environment for ``icefall``.
.. HINT::
We recommend you to use a GPU or several GPUs to run this recipe.
In this tutorial, you will learn:
- (1) How to prepare data for training and decoding
- (2) How to start the training, either with a single GPU or multiple GPUs
- (3) How to do decoding after training, with n-gram LM rescoring and attention decoder rescoring
- (4) How to use a pre-trained model, provided by us
Data preparation
----------------
.. code-block:: bash
$ cd egs/librispeech/ASR
$ ./prepare.sh
The script ``./prepare.sh`` handles the data preparation for you, **automagically**.
All you need to do is to run it.
The data preparation contains several stages, you can use the following two
options:
- ``--stage``
- ``--stop-stage``
to control which stage(s) should be run. By default, all stages are executed.
For example,
.. code-block:: bash
$ cd egs/yesno/ASR
$ ./prepare.sh --stage 0 --stop-stage 0
means to run only stage 0.
To run stage 2 to stage 5, use:
.. code-block:: bash
$ ./prepare.sh --stage 2 --stop-stage 5
.. HINT::
If you have pre-downloaded the `LibriSpeech <https://www.openslr.org/12>`_
dataset and the `musan <http://www.openslr.org/17/>`_ dataset, say,
they are saved in ``/tmp/LibriSpeech`` and ``/tmp/musan``, you can modify
the ``dl_dir`` variable in ``./prepare.sh`` to point to ``/tmp`` so that
``./prepare.sh`` won't re-download them.
.. NOTE::
All generated files by ``./prepare.sh``, e.g., features, lexicon, etc,
are saved in ``./data`` directory.
Training
--------
Configurable options
~~~~~~~~~~~~~~~~~~~~
.. code-block:: bash
$ cd egs/librispeech/ASR
$ ./conformer_ctc/train.py --help
shows you the training options that can be passed from the commandline.
The following options are used quite often:
- ``--full-libri``
If it's True, the training part uses all the training data, i.e.,
960 hours. Otherwise, the training part uses only the subset
``train-clean-100``, which has 100 hours of training data.
.. CAUTION::
The training set is perturbed by speed with two factors: 0.9 and 1.1.
If ``--full-libri`` is True, each epoch actually processes
``3x960 == 2880`` hours of data.
- ``--num-epochs``
It is the number of epochs to train. For instance,
``./conformer_ctc/train.py --num-epochs 30`` trains for 30 epochs
and generates ``epoch-0.pt``, ``epoch-1.pt``, ..., ``epoch-29.pt``
in the folder ``./conformer_ctc/exp``.
- ``--start-epoch``
It's used to resume training.
``./conformer_ctc/train.py --start-epoch 10`` loads the
checkpoint ``./conformer_ctc/exp/epoch-9.pt`` and starts
training from epoch 10, based on the state from epoch 9.
- ``--world-size``
It is used for multi-GPU single-machine DDP training.
- (a) If it is 1, then no DDP training is used.
- (b) If it is 2, then GPU 0 and GPU 1 are used for DDP training.
The following shows some use cases with it.
**Use case 1**: You have 4 GPUs, but you only want to use GPU 0 and
GPU 2 for training. You can do the following:
.. code-block:: bash
$ cd egs/librispeech/ASR
$ export CUDA_VISIBLE_DEVICES="0,2"
$ ./conformer_ctc/train.py --world-size 2
**Use case 2**: You have 4 GPUs and you want to use all of them
for training. You can do the following:
.. code-block:: bash
$ cd egs/librispeech/ASR
$ ./conformer_ctc/train.py --world-size 4
**Use case 3**: You have 4 GPUs but you only want to use GPU 3
for training. You can do the following:
.. code-block:: bash
$ cd egs/librispeech/ASR
$ export CUDA_VISIBLE_DEVICES="3"
$ ./conformer_ctc/train.py --world-size 1
.. CAUTION::
Only multi-GPU single-machine DDP training is implemented at present.
Multi-GPU multi-machine DDP training will be added later.
- ``--max-duration``
It specifies the number of seconds over all utterances in a
batch, before **padding**.
If you encounter CUDA OOM, please reduce it. For instance, if
your are using V100 NVIDIA GPU, we recommend you to set it to ``200``.
.. HINT::
Due to padding, the number of seconds of all utterances in a
batch will usually be larger than ``--max-duration``.
A larger value for ``--max-duration`` may cause OOM during training,
while a smaller value may increase the training time. You have to
tune it.
Pre-configured options
~~~~~~~~~~~~~~~~~~~~~~
There are some training options, e.g., learning rate,
number of warmup steps, results dir, etc,
that are not passed from the commandline.
They are pre-configured by the function ``get_params()`` in
`conformer_ctc/train.py <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/conformer_ctc/train.py>`_
You don't need to change these pre-configured parameters. If you really need to change
them, please modify ``./conformer_ctc/train.py`` directly.
Training logs
~~~~~~~~~~~~~
Training logs and checkpoints are saved in ``conformer_ctc/exp``.
You will find the following files in that directory:
- ``epoch-0.pt``, ``epoch-1.pt``, ...
These are checkpoint files, containing model ``state_dict`` and optimizer ``state_dict``.
To resume training from some checkpoint, say ``epoch-10.pt``, you can use:
.. code-block:: bash
$ ./conformer_ctc/train.py --start-epoch 11
- ``tensorboard/``
This folder contains TensorBoard logs. Training loss, validation loss, learning
rate, etc, are recorded in these logs. You can visualize them by:
.. code-block:: bash
$ cd conformer_ctc/exp/tensorboard
$ tensorboard dev upload --logdir . --description "Conformer CTC training for LibriSpeech with icefall"
It will print something like below:
.. code-block::
TensorFlow installation not found - running with reduced feature set.
Upload started and will continue reading any new data as it's added to the logdir.
To stop uploading, press Ctrl-C.
New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/lzGnETjwRxC3yghNMd4kPw/
[2021-08-24T16:42:43] Started scanning logdir.
Uploading 4540 scalars...
Note there is a URL in the above output, click it and you will see
the following screenshot:
.. figure:: images/librispeech-conformer-ctc-tensorboard-log.png
:width: 600
:alt: TensorBoard screenshot
:align: center
:target: https://tensorboard.dev/experiment/lzGnETjwRxC3yghNMd4kPw/
TensorBoard screenshot.
- ``log/log-train-xxxx``
It is the detailed training log in text format, same as the one
you saw printed to the console during training.
Usage examples
~~~~~~~~~~~~~~
The following shows typical use cases:
**Case 1**
^^^^^^^^^^
.. code-block:: bash
$ cd egs/librispeech/ASR
$ ./conformer_ctc/train.py --max-duration 200 --full-libri 0
It uses ``--max-duration`` of 200 to avoid OOM. Also, it uses only
a subset of the LibriSpeech data for training.
**Case 2**
^^^^^^^^^^
.. code-block:: bash
$ cd egs/librispeech/ASR
$ export CUDA_VISIBLE_DEVICES="0,3"
$ ./conformer_ctc/train.py --world-size 2
It uses GPU 0 and GPU 3 for DDP training.
**Case 3**
^^^^^^^^^^
.. code-block:: bash
$ cd egs/librispeech/ASR
$ ./conformer_ctc/train.py --num-epochs 10 --start-epoch 3
It loads checkpoint ``./conformer_ctc/exp/epoch-2.pt`` and starts
training from epoch 3. Also, it trains for 10 epochs.
Decoding
--------
The decoding part uses checkpoints saved by the training part, so you have
to run the training part first.
.. code-block:: bash
$ cd egs/librispeech/ASR
$ ./conformer_ctc/decode.py --help
shows the options for decoding.
The commonly used options are:
- ``--method``
This specifies the decoding method.
The following command uses attention decoder for rescoring:
.. code-block::
$ cd egs/librispeech/ASR
$ ./conformer_ctc/decode.py --method attention-decoder --max-duration 30 --lattice-score-scale 0.5
- ``--lattice-score-scale``
It is used to scaled down lattice scores so that we can more unique
paths for rescoring.
- ``--max-duration``
It has the same meaning as the one during training. A larger
value may cause OOM.
Pre-trained Model
-----------------
We have uploaded the pre-trained model to
`<https://huggingface.co/pkufool/icefall_asr_librispeech_conformer_ctc>`_.
We describe how to use the pre-trained model to transcribe a sound file or
multiple sound files in the following.
Install kaldifeat
~~~~~~~~~~~~~~~~~
`kaldifeat <https://github.com/csukuangfj/kaldifeat>`_ is used to
extract features for a single sound file or multiple soundfiles
at the same time.
Please refer to `<https://github.com/csukuangfj/kaldifeat>`_ for installation.
Download the pre-trained model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The following commands describe how to download the pre-trained model:
.. code-block::
$ cd egs/librispeech/ASR
$ mkdir tmp
$ cd tmp
$ git lfs install
$ git clone https://huggingface.co/pkufool/icefall_asr_librispeech_conformer_ctc
.. CAUTION::
You have to use ``git lfs`` to download the pre-trained model.
After downloading, you will have the following files:
.. code-block:: bash
$ cd egs/librispeech/ASR
$ tree tmp
.. code-block:: bash
tmp
`-- icefall_asr_librispeech_conformer_ctc
|-- README.md
|-- data
| |-- lang_bpe
| | |-- HLG.pt
| | |-- bpe.model
| | |-- tokens.txt
| | `-- words.txt
| `-- lm
| `-- G_4_gram.pt
|-- exp
| `-- pretrained.pt
`-- test_wavs
|-- 1089-134686-0001.flac
|-- 1221-135766-0001.flac
|-- 1221-135766-0002.flac
`-- trans.txt
6 directories, 11 files
**File descriptions**:
- ``data/lang_bpe/HLG.pt``
It is the decoding graph.
- ``data/lang_bpe/bpe.model``
It is a sentencepiece model. You can use it to reproduce our results.
- ``data/lang_bpe/tokens.txt``
It contains tokens and their IDs, generated from ``bpe.model``.
Provided only for convenience so that you can look up the SOS/EOS ID easily.
- ``data/lang_bpe/words.txt``
It contains words and their IDs.
- ``data/lm/G_4_gram.pt``
It is a 4-gram LM, useful for LM rescoring.
- ``exp/pretrained.pt``
It contains pre-trained model parameters, obtained by averaging
checkpoints from ``epoch-15.pt`` to ``epoch-34.pt``.
Note: We have removed optimizer ``state_dict`` to reduce file size.
- ``test_waves/*.flac``
It contains some test sound files from LibriSpeech ``test-clean`` dataset.
- `test_waves/trans.txt`
It contains the reference transcripts for the sound files in `test_waves/`.
The information of the test sound files is listed below:
.. code-block:: bash
$ soxi tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/*.flac
Input File : 'tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac'
Channels : 1
Sample Rate : 16000
Precision : 16-bit
Duration : 00:00:06.62 = 106000 samples ~ 496.875 CDDA sectors
File Size : 116k
Bit Rate : 140k
Sample Encoding: 16-bit FLAC
Input File : 'tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac'
Channels : 1
Sample Rate : 16000
Precision : 16-bit
Duration : 00:00:16.71 = 267440 samples ~ 1253.62 CDDA sectors
File Size : 343k
Bit Rate : 164k
Sample Encoding: 16-bit FLAC
Input File : 'tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac'
Channels : 1
Sample Rate : 16000
Precision : 16-bit
Duration : 00:00:04.83 = 77200 samples ~ 361.875 CDDA sectors
File Size : 105k
Bit Rate : 174k
Sample Encoding: 16-bit FLAC
Total Duration of 3 files: 00:00:28.16
Usage
~~~~~
.. code-block::
$ cd egs/librispeech/ASR
$ ./conformer_ctc/pretrained.py --help
displays the help information.
It supports three decoding methods:
- HLG decoding
- HLG + n-gram LM rescoring
- HLG + n-gram LM rescoring + attention decoder rescoring
HLG decoding
^^^^^^^^^^^^
HLG decoding uses the best path of the decoding lattice as the decoding result.
The command to run HLG decoding is:
.. code-block:: bash
$ cd egs/librispeech/ASR
$ ./conformer_ctc/pretrained.py \
--checkpoint ./tmp/icefall_asr_librispeech_conformer_ctc/exp/pretrained.pt \
--words-file ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/words.txt \
--HLG ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt \
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac \
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac \
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac
The output is given below:
.. code-block::
2021-08-20 11:03:05,712 INFO [pretrained.py:217] device: cuda:0
2021-08-20 11:03:05,712 INFO [pretrained.py:219] Creating model
2021-08-20 11:03:11,345 INFO [pretrained.py:238] Loading HLG from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt
2021-08-20 11:03:18,442 INFO [pretrained.py:255] Constructing Fbank computer
2021-08-20 11:03:18,444 INFO [pretrained.py:265] Reading sound files: ['./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac']
2021-08-20 11:03:18,507 INFO [pretrained.py:271] Decoding started
2021-08-20 11:03:18,795 INFO [pretrained.py:300] Use HLG decoding
2021-08-20 11:03:19,149 INFO [pretrained.py:339]
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac:
AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac:
GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONOURED
BOSOM TO CONNECT HER PARENT FOR EVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac:
YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
2021-08-20 11:03:19,149 INFO [pretrained.py:341] Decoding Done
HLG decoding + LM rescoring
^^^^^^^^^^^^^^^^^^^^^^^^^^^
It uses an n-gram LM to rescore the decoding lattice and the best
path of the rescored lattice is the decoding result.
The command to run HLG decoding + LM rescoring is:
.. code-block:: bash
$ cd egs/librispeech/ASR
$ ./conformer_ctc/pretrained.py \
--checkpoint ./tmp/icefall_asr_librispeech_conformer_ctc/exp/pretrained.pt \
--words-file ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/words.txt \
--HLG ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt \
--method whole-lattice-rescoring \
--G ./tmp/icefall_asr_librispeech_conformer_ctc/data/lm/G_4_gram.pt \
--ngram-lm-scale 0.8 \
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac \
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac \
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac
Its output is:
.. code-block::
2021-08-20 11:12:17,565 INFO [pretrained.py:217] device: cuda:0
2021-08-20 11:12:17,565 INFO [pretrained.py:219] Creating model
2021-08-20 11:12:23,728 INFO [pretrained.py:238] Loading HLG from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt
2021-08-20 11:12:30,035 INFO [pretrained.py:246] Loading G from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lm/G_4_gram.pt
2021-08-20 11:13:10,779 INFO [pretrained.py:255] Constructing Fbank computer
2021-08-20 11:13:10,787 INFO [pretrained.py:265] Reading sound files: ['./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac']
2021-08-20 11:13:10,798 INFO [pretrained.py:271] Decoding started
2021-08-20 11:13:11,085 INFO [pretrained.py:305] Use HLG decoding + LM rescoring
2021-08-20 11:13:11,736 INFO [pretrained.py:339]
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac:
AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac:
GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONOURED
BOSOM TO CONNECT HER PARENT FOR EVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac:
YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
2021-08-20 11:13:11,737 INFO [pretrained.py:341] Decoding Done
HLG decoding + LM rescoring + attention decoder rescoring
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
It uses an n-gram LM to rescore the decoding lattice, extracts
n paths from the rescored lattice, recores the extracted paths with
an attention decoder. The path with the highest score is the decoding result.
The command to run HLG decoding + LM rescoring + attention decoder rescoring is:
.. code-block:: bash
$ cd egs/librispeech/ASR
$ ./conformer_ctc/pretrained.py \
--checkpoint ./tmp/icefall_asr_librispeech_conformer_ctc/exp/pretrained.pt \
--words-file ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/words.txt \
--HLG ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt \
--method attention-decoder \
--G ./tmp/icefall_asr_librispeech_conformer_ctc/data/lm/G_4_gram.pt \
--ngram-lm-scale 1.3 \
--attention-decoder-scale 1.2 \
--lattice-score-scale 0.5 \
--num-paths 100 \
--sos-id 1 \
--eos-id 1 \
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac \
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac \
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac
The output is below:
.. code-block::
2021-08-20 11:19:11,397 INFO [pretrained.py:217] device: cuda:0
2021-08-20 11:19:11,397 INFO [pretrained.py:219] Creating model
2021-08-20 11:19:17,354 INFO [pretrained.py:238] Loading HLG from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt
2021-08-20 11:19:24,615 INFO [pretrained.py:246] Loading G from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lm/G_4_gram.pt
2021-08-20 11:20:04,576 INFO [pretrained.py:255] Constructing Fbank computer
2021-08-20 11:20:04,584 INFO [pretrained.py:265] Reading sound files: ['./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac']
2021-08-20 11:20:04,595 INFO [pretrained.py:271] Decoding started
2021-08-20 11:20:04,854 INFO [pretrained.py:313] Use HLG + LM rescoring + attention decoder rescoring
2021-08-20 11:20:05,805 INFO [pretrained.py:339]
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac:
AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac:
GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONOURED
BOSOM TO CONNECT HER PARENT FOR EVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac:
YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
2021-08-20 11:20:05,805 INFO [pretrained.py:341] Decoding Done
Colab notebook
--------------
We do provide a colab notebook for this recipe showing how to use a pre-trained model.
|librispeech asr conformer ctc colab notebook|
.. |librispeech asr conformer ctc colab notebook| image:: https://colab.research.google.com/assets/colab-badge.svg
:target: https://colab.research.google.com/drive/1huyupXAcHsUrKaWfI83iMEJ6J0Nh0213?usp=sharing
.. HINT::
Due to limited memory provided by Colab, you have to upgrade to Colab Pro to
run ``HLG decoding + LM rescoring`` and
``HLG decoding + LM rescoring + attention decoder rescoring``.
Otherwise, you can only run ``HLG decoding`` with Colab.
**Congratulations!** You have finished the librispeech ASR recipe with
conformer CTC models in ``icefall``.

Binary file not shown.

After

Width:  |  Height:  |  Size: 422 KiB

View File

@ -0,0 +1,322 @@
TDNN-LSTM-CTC
=============
This tutorial shows you how to run a TDNN-LSTM-CTC model with the `LibriSpeech <https://www.openslr.org/12>`_ dataset.
.. HINT::
We assume you have read the page :ref:`install icefall` and have setup
the environment for ``icefall``.
Data preparation
----------------
.. code-block:: bash
$ cd egs/librispeech/ASR
$ ./prepare.sh
The script ``./prepare.sh`` handles the data preparation for you, **automagically**.
All you need to do is to run it.
The data preparation contains several stages, you can use the following two
options:
- ``--stage``
- ``--stop-stage``
to control which stage(s) should be run. By default, all stages are executed.
For example,
.. code-block:: bash
$ cd egs/librispeech/ASR
$ ./prepare.sh --stage 0 --stop-stage 0
means to run only stage 0.
To run stage 2 to stage 5, use:
.. code-block:: bash
$ ./prepare.sh --stage 2 --stop-stage 5
Training
--------
Now describing the training of TDNN-LSTM-CTC model, contained in
the `tdnn_lstm_ctc <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/tdnn_lstm_ctc>`_
folder.
The command to run the training part is:
.. code-block:: bash
$ cd egs/librispeech/ASR
$ export CUDA_VISIBLE_DEVICES="0,1,2,3"
$ ./tdnn_lstm_ctc/train.py --world-size 4
By default, it will run ``20`` epochs. Training logs and checkpoints are saved
in ``tdnn_lstm_ctc/exp``.
In ``tdnn_lstm_ctc/exp``, you will find the following files:
- ``epoch-0.pt``, ``epoch-1.pt``, ..., ``epoch-19.pt``
These are checkpoint files, containing model ``state_dict`` and optimizer ``state_dict``.
To resume training from some checkpoint, say ``epoch-10.pt``, you can use:
.. code-block:: bash
$ ./tdnn_lstm_ctc/train.py --start-epoch 11
- ``tensorboard/``
This folder contains TensorBoard logs. Training loss, validation loss, learning
rate, etc, are recorded in these logs. You can visualize them by:
.. code-block:: bash
$ cd tdnn_lstm_ctc/exp/tensorboard
$ tensorboard dev upload --logdir . --description "TDNN LSTM training for librispeech with icefall"
- ``log/log-train-xxxx``
It is the detailed training log in text format, same as the one
you saw printed to the console during training.
To see available training options, you can use:
.. code-block:: bash
$ ./tdnn_lstm_ctc/train.py --help
Other training options, e.g., learning rate, results dir, etc., are
pre-configured in the function ``get_params()``
in `tdnn_lstm_ctc/train.py <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/tdnn_lstm_ctc/train.py>`_.
Normally, you don't need to change them. You can change them by modifying the code, if
you want.
Decoding
--------
The decoding part uses checkpoints saved by the training part, so you have
to run the training part first.
The command for decoding is:
.. code-block:: bash
$ export CUDA_VISIBLE_DEVICES="0"
$ ./tdnn_lstm_ctc/decode.py
You will see the WER in the output log.
Decoded results are saved in ``tdnn_lstm_ctc/exp``.
.. code-block:: bash
$ ./tdnn_lstm_ctc/decode.py --help
shows you the available decoding options.
Some commonly used options are:
- ``--epoch``
You can select which checkpoint to be used for decoding.
For instance, ``./tdnn_lstm_ctc/decode.py --epoch 10`` means to use
``./tdnn_lstm_ctc/exp/epoch-10.pt`` for decoding.
- ``--avg``
It's related to model averaging. It specifies number of checkpoints
to be averaged. The averaged model is used for decoding.
For example, the following command:
.. code-block:: bash
$ ./tdnn_lstm_ctc/decode.py --epoch 10 --avg 3
uses the average of ``epoch-8.pt``, ``epoch-9.pt`` and ``epoch-10.pt``
for decoding.
- ``--export``
If it is ``True``, i.e., ``./tdnn_lstm_ctc/decode.py --export 1``, the code
will save the averaged model to ``tdnn_lstm_ctc/exp/pretrained.pt``.
See :ref:`tdnn_lstm_ctc use a pre-trained model` for how to use it.
.. HINT::
There are several decoding methods provided in `tdnn_lstm_ctc/decode.py <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/tdnn_lstm_ctc/train.py>`_, you can change the decoding method by modifying ``method`` parameter in function ``get_params()``.
.. _tdnn_lstm_ctc use a pre-trained model:
Pre-trained Model
-----------------
We have uploaded the pre-trained model to
`<https://huggingface.co/pkufool/icefall_asr_librispeech_tdnn-lstm_ctc>`_.
The following shows you how to use the pre-trained model.
Download the pre-trained model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code-block:: bash
$ cd egs/librispeech/ASR
$ mkdir tmp
$ cd tmp
$ git lfs install
$ git clone https://huggingface.co/pkufool/icefall_asr_librispeech_tdnn-lstm_ctc
.. CAUTION::
You have to use ``git lfs`` to download the pre-trained model.
After downloading, you will have the following files:
.. code-block:: bash
$ cd egs/librispeech/ASR
$ tree tmp
.. code-block:: bash
tmp/
`-- icefall_asr_librispeech_tdnn-lstm_ctc
|-- README.md
|-- data
| |-- lang_phone
| | |-- HLG.pt
| | |-- tokens.txt
| | `-- words.txt
| `-- lm
| `-- G_4_gram.pt
|-- exp
| `-- pretrained.pt
`-- test_wavs
|-- 1089-134686-0001.flac
|-- 1221-135766-0001.flac
|-- 1221-135766-0002.flac
`-- trans.txt
6 directories, 10 files
Download kaldifeat
~~~~~~~~~~~~~~~~~~
`kaldifeat <https://github.com/csukuangfj/kaldifeat>`_ is used for extracting
features from a single or multiple sound files. Please refer to
`<https://github.com/csukuangfj/kaldifeat>`_ to install ``kaldifeat`` first.
Inference with a pre-trained model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code-block:: bash
$ cd egs/librispeech/ASR
$ ./tdnn_lstm_ctc/pretrained.py --help
shows the usage information of ``./tdnn_lstm_ctc/pretrained.py``.
To decode with ``1best`` method, we can use:
.. code-block:: bash
./tdnn_lstm_ctc/pretrained.py \
--checkpoint ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/exp/pretraind.pt \
--words-file ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/words.txt \
--HLG ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/HLG.pt \
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac \
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac \
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac
The output is:
.. code-block::
2021-08-24 16:57:13,315 INFO [pretrained.py:168] device: cuda:0
2021-08-24 16:57:13,315 INFO [pretrained.py:170] Creating model
2021-08-24 16:57:18,331 INFO [pretrained.py:182] Loading HLG from ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/HLG.pt
2021-08-24 16:57:27,581 INFO [pretrained.py:199] Constructing Fbank computer
2021-08-24 16:57:27,584 INFO [pretrained.py:209] Reading sound files: ['./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac']
2021-08-24 16:57:27,599 INFO [pretrained.py:215] Decoding started
2021-08-24 16:57:27,791 INFO [pretrained.py:245] Use HLG decoding
2021-08-24 16:57:28,098 INFO [pretrained.py:266]
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac:
AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac:
GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONORED BOSOM TO CONNECT HER PARENT FOREVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac:
YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
2021-08-24 16:57:28,099 INFO [pretrained.py:268] Decoding Done
To decode with ``whole-lattice-rescoring`` methond, you can use
.. code-block:: bash
./tdnn_lstm_ctc/pretrained.py \
--checkpoint ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/exp/pretraind.pt \
--words-file ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/words.txt \
--HLG ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/HLG.pt \
--method whole-lattice-rescoring \
--G ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lm/G_4_gram.pt \
--ngram-lm-scale 0.8 \
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac \
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac \
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac
The decoding output is:
.. code-block::
2021-08-24 16:39:24,725 INFO [pretrained.py:168] device: cuda:0
2021-08-24 16:39:24,725 INFO [pretrained.py:170] Creating model
2021-08-24 16:39:29,403 INFO [pretrained.py:182] Loading HLG from ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/HLG.pt
2021-08-24 16:39:40,631 INFO [pretrained.py:190] Loading G from ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lm/G_4_gram.pt
2021-08-24 16:39:53,098 INFO [pretrained.py:199] Constructing Fbank computer
2021-08-24 16:39:53,107 INFO [pretrained.py:209] Reading sound files: ['./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac']
2021-08-24 16:39:53,121 INFO [pretrained.py:215] Decoding started
2021-08-24 16:39:53,443 INFO [pretrained.py:250] Use HLG decoding + LM rescoring
2021-08-24 16:39:54,010 INFO [pretrained.py:266]
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac:
AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac:
GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONORED BOSOM TO CONNECT HER PARENT FOREVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac:
YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
2021-08-24 16:39:54,010 INFO [pretrained.py:268] Decoding Done
Colab notebook
--------------
We provide a colab notebook for decoding with pre-trained model.
|librispeech tdnn_lstm_ctc colab notebook|
.. |librispeech tdnn_lstm_ctc colab notebook| image:: https://colab.research.google.com/assets/colab-badge.svg
:target: https://colab.research.google.com/drive/1kNmDXNMwREi0rZGAOIAOJo93REBuOTcd
**Congratulations!** You have finished the TDNN-LSTM-CTC recipe on librispeech in ``icefall``.

View File

@ -0,0 +1,445 @@
yesno
=====
This page shows you how to run the `yesno <https://www.openslr.org/1>`_ recipe. It contains:
- (1) Prepare data for training
- (2) Train a TDNN model
- (a) View text format logs and visualize TensorBoard logs
- (b) Select device type, i.e., CPU and GPU, for training
- (c) Change training options
- (d) Resume training from a checkpoint
- (3) Decode with a trained model
- (a) Select a checkpoint for decoding
- (b) Model averaging
- (4) Colab notebook
- (a) It shows you step by step how to setup the environment, how to do training,
and how to do decoding
- (b) How to use a pre-trained model
- (5) Inference with a pre-trained model
- (a) Download a pre-trained model, provided by us
- (b) Decode a single sound file with a pre-trained model
- (c) Decode multiple sound files at the same time
It does **NOT** show you:
- (1) How to train with multiple GPUs
The ``yesno`` dataset is so small that CPU is more than enough
for training as well as for decoding.
- (2) How to use LM rescoring for decoding
The dataset does not have an LM for rescoring.
.. HINT::
We assume you have read the page :ref:`install icefall` and have setup
the environment for ``icefall``.
.. HINT::
You **don't** need a **GPU** to run this recipe. It can be run on a **CPU**.
The training part takes less than 30 **seconds** on a CPU and you will get
the following WER at the end::
[test_set] %WER 0.42% [1 / 240, 0 ins, 1 del, 0 sub ]
Data preparation
----------------
.. code-block:: bash
$ cd egs/yesno/ASR
$ ./prepare.sh
The script ``./prepare.sh`` handles the data preparation for you, **automagically**.
All you need to do is to run it.
The data preparation contains several stages, you can use the following two
options:
- ``--stage``
- ``--stop-stage``
to control which stage(s) should be run. By default, all stages are executed.
For example,
.. code-block:: bash
$ cd egs/yesno/ASR
$ ./prepare.sh --stage 0 --stop-stage 0
means to run only stage 0.
To run stage 2 to stage 5, use:
.. code-block:: bash
$ ./prepare.sh --stage 2 --stop-stage 5
Training
--------
We provide only a TDNN model, contained in
the `tdnn <https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR/tdnn>`_
folder, for ``yesno``.
The command to run the training part is:
.. code-block:: bash
$ cd egs/yesno/ASR
$ export CUDA_VISIBLE_DEVICES=""
$ ./tdnn/train.py
By default, it will run ``15`` epochs. Training logs and checkpoints are saved
in ``tdnn/exp``.
In ``tdnn/exp``, you will find the following files:
- ``epoch-0.pt``, ``epoch-1.pt``, ...
These are checkpoint files, containing model ``state_dict`` and optimizer ``state_dict``.
To resume training from some checkpoint, say ``epoch-10.pt``, you can use:
.. code-block:: bash
$ ./tdnn/train.py --start-epoch 11
- ``tensorboard/``
This folder contains TensorBoard logs. Training loss, validation loss, learning
rate, etc, are recorded in these logs. You can visualize them by:
.. code-block:: bash
$ cd tdnn/exp/tensorboard
$ tensorboard dev upload --logdir . --description "TDNN training for yesno with icefall"
It will print something like below:
.. code-block::
TensorFlow installation not found - running with reduced feature set.
Upload started and will continue reading any new data as it's added to the logdir.
To stop uploading, press Ctrl-C.
New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/yKUbhb5wRmOSXYkId1z9eg/
[2021-08-23T23:49:41] Started scanning logdir.
[2021-08-23T23:49:42] Total uploaded: 135 scalars, 0 tensors, 0 binary objects
Listening for new data in logdir...
Note there is a URL in the above output, click it and you will see
the following screenshot:
.. figure:: images/yesno-tdnn-tensorboard-log.png
:width: 600
:alt: TensorBoard screenshot
:align: center
:target: https://tensorboard.dev/experiment/yKUbhb5wRmOSXYkId1z9eg/
TensorBoard screenshot.
- ``log/log-train-xxxx``
It is the detailed training log in text format, same as the one
you saw printed to the console during training.
.. NOTE::
By default, ``./tdnn/train.py`` uses GPU 0 for training if GPUs are available.
If you have two GPUs, say, GPU 0 and GPU 1, and you want to use GPU 1 for
training, you can run:
.. code-block:: bash
$ export CUDA_VISIBLE_DEVICES="1"
$ ./tdnn/train.py
Since the ``yesno`` dataset is very small, containing only 30 sound files
for training, and the model in use is also very small, we use:
.. code-block:: bash
$ export CUDA_VISIBLE_DEVICES=""
so that ``./tdnn/train.py`` uses CPU during training.
If you don't have GPUs, then you don't need to
run ``export CUDA_VISIBLE_DEVICES=""``.
To see available training options, you can use:
.. code-block:: bash
$ ./tdnn/train.py --help
Other training options, e.g., learning rate, results dir, etc., are
pre-configured in the function ``get_params()``
in `tdnn/train.py <https://github.com/k2-fsa/icefall/blob/master/egs/yesno/ASR/tdnn/train.py>`_.
Normally, you don't need to change them. You can change them by modifying the code, if
you want.
Decoding
--------
The decoding part uses checkpoints saved by the training part, so you have
to run the training part first.
The command for decoding is:
.. code-block:: bash
$ export CUDA_VISIBLE_DEVICES=""
$ ./tdnn/decode.py
You will see the WER in the output log.
Decoded results are saved in ``tdnn/exp``.
.. code-block:: bash
$ ./tdnn/decode.py --help
shows you the available decoding options.
Some commonly used options are:
- ``--epoch``
You can select which checkpoint to be used for decoding.
For instance, ``./tdnn/decode.py --epoch 10`` means to use
``./tdnn/exp/epoch-10.pt`` for decoding.
- ``--avg``
It's related to model averaging. It specifies number of checkpoints
to be averaged. The averaged model is used for decoding.
For example, the following command:
.. code-block:: bash
$ ./tdnn/decode.py --epoch 10 --avg 3
uses the average of ``epoch-8.pt``, ``epoch-9.pt`` and ``epoch-10.pt``
for decoding.
- ``--export``
If it is ``True``, i.e., ``./tdnn/decode.py --export 1``, the code
will save the averaged model to ``tdnn/exp/pretrained.pt``.
See :ref:`yesno use a pre-trained model` for how to use it.
.. _yesno use a pre-trained model:
Pre-trained Model
-----------------
We have uploaded the pre-trained model to
`<https://huggingface.co/csukuangfj/icefall_asr_yesno_tdnn>`_.
The following shows you how to use the pre-trained model.
Download the pre-trained model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code-block:: bash
$ cd egs/yesno/ASR
$ mkdir tmp
$ cd tmp
$ git lfs install
$ git clone https://huggingface.co/csukuangfj/icefall_asr_yesno_tdnn
.. CAUTION::
You have to use ``git lfs`` to download the pre-trained model.
After downloading, you will have the following files:
.. code-block:: bash
$ cd egs/yesno/ASR
$ tree tmp
.. code-block:: bash
tmp/
`-- icefall_asr_yesno_tdnn
|-- README.md
|-- lang_phone
| |-- HLG.pt
| |-- L.pt
| |-- L_disambig.pt
| |-- Linv.pt
| |-- lexicon.txt
| |-- lexicon_disambig.txt
| |-- tokens.txt
| `-- words.txt
|-- lm
| |-- G.arpa
| `-- G.fst.txt
|-- pretrained.pt
`-- test_waves
|-- 0_0_0_1_0_0_0_1.wav
|-- 0_0_1_0_0_0_1_0.wav
|-- 0_0_1_0_0_1_1_1.wav
|-- 0_0_1_0_1_0_0_1.wav
|-- 0_0_1_1_0_0_0_1.wav
|-- 0_0_1_1_0_1_1_0.wav
|-- 0_0_1_1_1_0_0_0.wav
|-- 0_0_1_1_1_1_0_0.wav
|-- 0_1_0_0_0_1_0_0.wav
|-- 0_1_0_0_1_0_1_0.wav
|-- 0_1_0_1_0_0_0_0.wav
|-- 0_1_0_1_1_1_0_0.wav
|-- 0_1_1_0_0_1_1_1.wav
|-- 0_1_1_1_0_0_1_0.wav
|-- 0_1_1_1_1_0_1_0.wav
|-- 1_0_0_0_0_0_0_0.wav
|-- 1_0_0_0_0_0_1_1.wav
|-- 1_0_0_1_0_1_1_1.wav
|-- 1_0_1_1_0_1_1_1.wav
|-- 1_0_1_1_1_1_0_1.wav
|-- 1_1_0_0_0_1_1_1.wav
|-- 1_1_0_0_1_0_1_1.wav
|-- 1_1_0_1_0_1_0_0.wav
|-- 1_1_0_1_1_0_0_1.wav
|-- 1_1_0_1_1_1_1_0.wav
|-- 1_1_1_0_0_1_0_1.wav
|-- 1_1_1_0_1_0_1_0.wav
|-- 1_1_1_1_0_0_1_0.wav
|-- 1_1_1_1_1_0_0_0.wav
`-- 1_1_1_1_1_1_1_1.wav
4 directories, 42 files
.. code-block:: bash
$ soxi tmp/icefall_asr_yesno_tdnn/test_waves/0_0_1_0_1_0_0_1.wav
Input File : 'tmp/icefall_asr_yesno_tdnn/test_waves/0_0_1_0_1_0_0_1.wav'
Channels : 1
Sample Rate : 8000
Precision : 16-bit
Duration : 00:00:06.76 = 54080 samples ~ 507 CDDA sectors
File Size : 108k
Bit Rate : 128k
Sample Encoding: 16-bit Signed Integer PCM
- ``0_0_1_0_1_0_0_1.wav``
0 means No; 1 means Yes. No and Yes are not in English,
but in `Hebrew <https://en.wikipedia.org/wiki/Hebrew_language>`_.
So this file contains ``NO NO YES NO YES NO NO YES``.
Download kaldifeat
~~~~~~~~~~~~~~~~~~
`kaldifeat <https://github.com/csukuangfj/kaldifeat>`_ is used for extracting
features from a single or multiple sound files. Please refer to
`<https://github.com/csukuangfj/kaldifeat>`_ to install ``kaldifeat`` first.
Inference with a pre-trained model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code-block:: bash
$ cd egs/yesno/ASR
$ ./tdnn/pretrained.py --help
shows the usage information of ``./tdnn/pretrained.py``.
To decode a single file, we can use:
.. code-block:: bash
./tdnn/pretrained.py \
--checkpoint ./tmp/icefall_asr_yesno_tdnn/pretrained.pt \
--words-file ./tmp/icefall_asr_yesno_tdnn/lang_phone/words.txt \
--HLG ./tmp/icefall_asr_yesno_tdnn/lang_phone/HLG.pt \
./tmp/icefall_asr_yesno_tdnn/test_waves/0_0_1_0_1_0_0_1.wav
The output is:
.. code-block::
2021-08-24 12:22:51,621 INFO [pretrained.py:119] {'feature_dim': 23, 'num_classes': 4, 'sample_rate': 8000, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'checkpoint': './tmp/icefall_asr_yesno_tdnn/pretrained.pt', 'words_file': './tmp/icefall_asr_yesno_tdnn/lang_phone/words.txt', 'HLG': './tmp/icefall_asr_yesno_tdnn/lang_phone/HLG.pt', 'sound_files': ['./tmp/icefall_asr_yesno_tdnn/test_waves/0_0_1_0_1_0_0_1.wav']}
2021-08-24 12:22:51,645 INFO [pretrained.py:125] device: cpu
2021-08-24 12:22:51,645 INFO [pretrained.py:127] Creating model
2021-08-24 12:22:51,650 INFO [pretrained.py:139] Loading HLG from ./tmp/icefall_asr_yesno_tdnn/lang_phone/HLG.pt
2021-08-24 12:22:51,651 INFO [pretrained.py:143] Constructing Fbank computer
2021-08-24 12:22:51,652 INFO [pretrained.py:153] Reading sound files: ['./tmp/icefall_asr_yesno_tdnn/test_waves/0_0_1_0_1_0_0_1.wav']
2021-08-24 12:22:51,684 INFO [pretrained.py:159] Decoding started
2021-08-24 12:22:51,708 INFO [pretrained.py:198]
./tmp/icefall_asr_yesno_tdnn/test_waves/0_0_1_0_1_0_0_1.wav:
NO NO YES NO YES NO NO YES
2021-08-24 12:22:51,708 INFO [pretrained.py:200] Decoding Done
You can see that for the sound file ``0_0_1_0_1_0_0_1.wav``, the decoding result is
``NO NO YES NO YES NO NO YES``.
To decode **multiple** files at the same time, you can use
.. code-block:: bash
./tdnn/pretrained.py \
--checkpoint ./tmp/icefall_asr_yesno_tdnn/pretrained.pt \
--words-file ./tmp/icefall_asr_yesno_tdnn/lang_phone/words.txt \
--HLG ./tmp/icefall_asr_yesno_tdnn/lang_phone/HLG.pt \
./tmp/icefall_asr_yesno_tdnn/test_waves/0_0_1_0_1_0_0_1.wav \
./tmp/icefall_asr_yesno_tdnn/test_waves/1_0_1_1_0_1_1_1.wav
The decoding output is:
.. code-block::
2021-08-24 12:25:20,159 INFO [pretrained.py:119] {'feature_dim': 23, 'num_classes': 4, 'sample_rate': 8000, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'checkpoint': './tmp/icefall_asr_yesno_tdnn/pretrained.pt', 'words_file': './tmp/icefall_asr_yesno_tdnn/lang_phone/words.txt', 'HLG': './tmp/icefall_asr_yesno_tdnn/lang_phone/HLG.pt', 'sound_files': ['./tmp/icefall_asr_yesno_tdnn/test_waves/0_0_1_0_1_0_0_1.wav', './tmp/icefall_asr_yesno_tdnn/test_waves/1_0_1_1_0_1_1_1.wav']}
2021-08-24 12:25:20,181 INFO [pretrained.py:125] device: cpu
2021-08-24 12:25:20,181 INFO [pretrained.py:127] Creating model
2021-08-24 12:25:20,185 INFO [pretrained.py:139] Loading HLG from ./tmp/icefall_asr_yesno_tdnn/lang_phone/HLG.pt
2021-08-24 12:25:20,186 INFO [pretrained.py:143] Constructing Fbank computer
2021-08-24 12:25:20,187 INFO [pretrained.py:153] Reading sound files: ['./tmp/icefall_asr_yesno_tdnn/test_waves/0_0_1_0_1_0_0_1.wav',
'./tmp/icefall_asr_yesno_tdnn/test_waves/1_0_1_1_0_1_1_1.wav']
2021-08-24 12:25:20,213 INFO [pretrained.py:159] Decoding started
2021-08-24 12:25:20,287 INFO [pretrained.py:198]
./tmp/icefall_asr_yesno_tdnn/test_waves/0_0_1_0_1_0_0_1.wav:
NO NO YES NO YES NO NO YES
./tmp/icefall_asr_yesno_tdnn/test_waves/1_0_1_1_0_1_1_1.wav:
YES NO YES YES NO YES YES YES
2021-08-24 12:25:20,287 INFO [pretrained.py:200] Decoding Done
You can see again that it decodes correctly.
Colab notebook
--------------
We do provide a colab notebook for this recipe.
|yesno colab notebook|
.. |yesno colab notebook| image:: https://colab.research.google.com/assets/colab-badge.svg
:target: https://colab.research.google.com/drive/1tIjjzaJc3IvGyKiMCDWO-TSnBgkcuN3B?usp=sharing
**Congratulations!** You have finished the simplest speech recognition recipe in ``icefall``.

View File

@ -1,64 +1,3 @@
## Data preparation
If you want to use `./prepare.sh` to download everything for you,
you can just run
```
./prepare.sh
```
If you have pre-downloaded the LibriSpeech dataset, please
read `./prepare.sh` and modify it to point to the location
of your dataset so that it won't re-download it. After modification,
please run
```
./prepare.sh
```
The script `./prepare.sh` prepares features, lexicon, LMs, etc.
All generated files are saved in the folder `./data`.
**HINT:** `./prepare.sh` supports options `--stage` and `--stop-stage`.
## TDNN-LSTM CTC training
The folder `tdnn_lstm_ctc` contains scripts for CTC training
with TDNN-LSTM models.
Pre-configured parameters for training and decoding are set in the function
`get_params()` within `tdnn_lstm_ctc/train.py`
and `tdnn_lstm_ctc/decode.py`.
Parameters that can be passed from the command-line can be found by
```
./tdnn_lstm_ctc/train.py --help
./tdnn_lstm_ctc/decode.py --help
```
If you have 4 GPUs on a machine and want to use GPU 0, 2, 3 for
mutli-GPU training, you can run
```
export CUDA_VISIBLE_DEVICES="0,2,3"
./tdnn_lstm_ctc/train.py \
--master-port 12345 \
--world-size 3
```
If you want to decode by averaging checkpoints `epoch-8.pt`,
`epoch-9.pt` and `epoch-10.pt`, you can run
```
./tdnn_lstm_ctc/decode.py \
--epoch 10 \
--avg 3
```
## Conformer CTC training
The folder `conformer-ctc` contains scripts for CTC training
with conformer models. The steps of running the training and
decoding are similar to `tdnn_lstm_ctc`.
Please refer to <https://icefall.readthedocs.io/en/latest/recipes/librispeech.html>
for how to run models in this recipe.

View File

@ -6,7 +6,7 @@
TensorBoard log is available at https://tensorboard.dev/experiment/GnRzq8WWQW62dK4bklXBTg/#scalars
Pretrained model is available at https://huggingface.co/pkufool/conformer_ctc
Pretrained model is available at https://huggingface.co/pkufool/icefall_asr_librispeech_conformer_ctc
The best decoding results (WER) are listed below, we got this results by averaging models from epoch 15 to 34, and using `attention-decoder` decoder with num_paths equals to 100.
@ -21,3 +21,26 @@ To get more unique paths, we scaled the lattice.scores with 0.5 (see https://git
|test-clean|1.3|1.2|
|test-other|1.2|1.1|
### LibriSpeech training results (Tdnn-Lstm)
#### 2021-08-24
(Wei Kang): Result of phone based Tdnn-Lstm model.
Icefall version: https://github.com/k2-fsa/icefall/commit/caa0b9e9425af27e0c6211048acb55a76ed5d315
Pretrained model is available at https://huggingface.co/pkufool/icefall_asr_librispeech_tdnn-lstm_ctc
The best decoding results (WER) are listed below, we got this results by averaging models from epoch 19 to 14, and using `whole-lattice-rescoring` decoding method.
||test-clean|test-other|
|--|--|--|
|WER| 6.59% | 17.69% |
We searched the lm_score_scale for best results, the scales that produced the WER above are also listed below.
||lm_scale|
|--|--|
|test-clean|0.8|
|test-other|0.9|

View File

@ -1,351 +1,3 @@
# How to use a pre-trained model to transcribe a sound file or multiple sound files
(See the bottom of this document for the link to a colab notebook.)
You need to prepare 4 files:
- a model checkpoint file, e.g., epoch-20.pt
- HLG.pt, the decoding graph
- words.txt, the word symbol table
- a sound file, whose sampling rate has to be 16 kHz.
Supported formats are those supported by `torchaudio.load()`,
e.g., wav and flac.
Also, you need to install `kaldifeat`. Please refer to
<https://github.com/csukuangfj/kaldifeat> for installation.
```bash
./conformer_ctc/pretrained.py --help
```
displays the help information.
## HLG decoding
Once you have the above files ready and have `kaldifeat` installed,
you can run:
```bash
./conformer_ctc/pretrained.py \
--checkpoint /path/to/your/checkpoint.pt \
--words-file /path/to/words.txt \
--HLG /path/to/HLG.pt \
/path/to/your/sound.wav
```
and you will see the transcribed result.
If you want to transcribe multiple files at the same time, you can use:
```bash
./conformer_ctc/pretrained.py \
--checkpoint /path/to/your/checkpoint.pt \
--words-file /path/to/words.txt \
--HLG /path/to/HLG.pt \
/path/to/your/sound1.wav \
/path/to/your/sound2.wav \
/path/to/your/sound3.wav
```
**Note**: This is the fastest decoding method.
## HLG decoding + LM rescoring
`./conformer_ctc/pretrained.py` also supports `whole lattice LM rescoring`
and `attention decoder rescoring`.
To use whole lattice LM rescoring, you also need the following files:
- G.pt, e.g., `data/lm/G_4_gram.pt` if you have run `./prepare.sh`
The command to run decoding with LM rescoring is:
```bash
./conformer_ctc/pretrained.py \
--checkpoint /path/to/your/checkpoint.pt \
--words-file /path/to/words.txt \
--HLG /path/to/HLG.pt \
--method whole-lattice-rescoring \
--G data/lm/G_4_gram.pt \
--ngram-lm-scale 0.8 \
/path/to/your/sound1.wav \
/path/to/your/sound2.wav \
/path/to/your/sound3.wav
```
## HLG Decoding + LM rescoring + attention decoder rescoring
To use attention decoder for rescoring, you need the following extra information:
- sos token ID
- eos token ID
The command to run decoding with attention decoder rescoring is:
```bash
./conformer_ctc/pretrained.py \
--checkpoint /path/to/your/checkpoint.pt \
--words-file /path/to/words.txt \
--HLG /path/to/HLG.pt \
--method attention-decoder \
--G data/lm/G_4_gram.pt \
--ngram-lm-scale 1.3 \
--attention-decoder-scale 1.2 \
--lattice-score-scale 0.5 \
--num-paths 100 \
--sos-id 1 \
--eos-id 1 \
/path/to/your/sound1.wav \
/path/to/your/sound2.wav \
/path/to/your/sound3.wav
```
# Decoding with a pre-trained model in action
We have uploaded a pre-trained model to <https://huggingface.co/pkufool/conformer_ctc>
The following shows the steps about the usage of the provided pre-trained model.
### (1) Download the pre-trained model
```bash
sudo apt-get install git-lfs
cd /path/to/icefall/egs/librispeech/ASR
git lfs install
mkdir tmp
cd tmp
git clone https://huggingface.co/pkufool/conformer_ctc
```
**CAUTION**: You have to install `git-lfst` to download the pre-trained model.
You will find the following files:
```
tmp
`-- conformer_ctc
|-- README.md
|-- data
| |-- lang_bpe
| | |-- HLG.pt
| | |-- bpe.model
| | |-- tokens.txt
| | `-- words.txt
| `-- lm
| `-- G_4_gram.pt
|-- exp
| `-- pretraind.pt
`-- test_wavs
|-- 1089-134686-0001.flac
|-- 1221-135766-0001.flac
|-- 1221-135766-0002.flac
`-- trans.txt
6 directories, 11 files
```
**File descriptions**:
- `data/lang_bpe/HLG.pt`
It is the decoding graph.
- `data/lang_bpe/bpe.model`
It is a sentencepiece model. You can use it to reproduce our results.
- `data/lang_bpe/tokens.txt`
It contains tokens and their IDs, generated from `bpe.model`.
Provided only for convienice so that you can look up the SOS/EOS ID easily.
- `data/lang_bpe/words.txt`
It contains words and their IDs.
- `data/lm/G_4_gram.pt`
It is a 4-gram LM, useful for LM rescoring.
- `exp/pretrained.pt`
It contains pre-trained model parameters, obtained by averaging
checkpoints from `epoch-15.pt` to `epoch-34.pt`.
Note: We have removed optimizer `state_dict` to reduce file size.
- `test_waves/*.flac`
It contains some test sound files from LibriSpeech `test-clean` dataset.
- `test_waves/trans.txt`
It contains the reference transcripts for the sound files in `test_waves/`.
The information of the test sound files is listed below:
```
$ soxi tmp/conformer_ctc/test_wavs/*.flac
Input File : 'tmp/conformer_ctc/test_wavs/1089-134686-0001.flac'
Channels : 1
Sample Rate : 16000
Precision : 16-bit
Duration : 00:00:06.62 = 106000 samples ~ 496.875 CDDA sectors
File Size : 116k
Bit Rate : 140k
Sample Encoding: 16-bit FLAC
Input File : 'tmp/conformer_ctc/test_wavs/1221-135766-0001.flac'
Channels : 1
Sample Rate : 16000
Precision : 16-bit
Duration : 00:00:16.71 = 267440 samples ~ 1253.62 CDDA sectors
File Size : 343k
Bit Rate : 164k
Sample Encoding: 16-bit FLAC
Input File : 'tmp/conformer_ctc/test_wavs/1221-135766-0002.flac'
Channels : 1
Sample Rate : 16000
Precision : 16-bit
Duration : 00:00:04.83 = 77200 samples ~ 361.875 CDDA sectors
File Size : 105k
Bit Rate : 174k
Sample Encoding: 16-bit FLAC
Total Duration of 3 files: 00:00:28.16
```
### (2) Use HLG decoding
```bash
cd /path/to/icefall/egs/librispeech/ASR
./conformer_ctc/pretrained.py \
--checkpoint ./tmp/conformer_ctc/exp/pretraind.pt \
--words-file ./tmp/conformer_ctc/data/lang_bpe/words.txt \
--HLG ./tmp/conformer_ctc/data/lang_bpe/HLG.pt \
./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac \
./tmp/conformer_ctc/test_wavs/1221-135766-0001.flac \
./tmp/conformer_ctc/test_wavs/1221-135766-0002.flac
```
The output is given below:
```
2021-08-20 11:03:05,712 INFO [pretrained.py:217] device: cuda:0
2021-08-20 11:03:05,712 INFO [pretrained.py:219] Creating model
2021-08-20 11:03:11,345 INFO [pretrained.py:238] Loading HLG from ./tmp/conformer_ctc/data/lang_bpe/HLG.pt
2021-08-20 11:03:18,442 INFO [pretrained.py:255] Constructing Fbank computer
2021-08-20 11:03:18,444 INFO [pretrained.py:265] Reading sound files: ['./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/conformer_ctc/test_wavs/1221-135766-0002.flac']
2021-08-20 11:03:18,507 INFO [pretrained.py:271] Decoding started
2021-08-20 11:03:18,795 INFO [pretrained.py:300] Use HLG decoding
2021-08-20 11:03:19,149 INFO [pretrained.py:339]
./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac:
AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
./tmp/conformer_ctc/test_wavs/1221-135766-0001.flac:
GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONOURED
BOSOM TO CONNECT HER PARENT FOR EVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN
./tmp/conformer_ctc/test_wavs/1221-135766-0002.flac:
YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
2021-08-20 11:03:19,149 INFO [pretrained.py:341] Decoding Done
```
### (3) Use HLG decoding + LM rescoring
```bash
./conformer_ctc/pretrained.py \
--checkpoint ./tmp/conformer_ctc/exp/pretraind.pt \
--words-file ./tmp/conformer_ctc/data/lang_bpe/words.txt \
--HLG ./tmp/conformer_ctc/data/lang_bpe/HLG.pt \
--method whole-lattice-rescoring \
--G ./tmp/conformer_ctc/data/lm/G_4_gram.pt \
--ngram-lm-scale 0.8 \
./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac \
./tmp/conformer_ctc/test_wavs/1221-135766-0001.flac \
./tmp/conformer_ctc/test_wavs/1221-135766-0002.flac
```
The output is:
```
2021-08-20 11:12:17,565 INFO [pretrained.py:217] device: cuda:0
2021-08-20 11:12:17,565 INFO [pretrained.py:219] Creating model
2021-08-20 11:12:23,728 INFO [pretrained.py:238] Loading HLG from ./tmp/conformer_ctc/data/lang_bpe/HLG.pt
2021-08-20 11:12:30,035 INFO [pretrained.py:246] Loading G from ./tmp/conformer_ctc/data/lm/G_4_gram.pt
2021-08-20 11:13:10,779 INFO [pretrained.py:255] Constructing Fbank computer
2021-08-20 11:13:10,787 INFO [pretrained.py:265] Reading sound files: ['./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/conformer_ctc/test_wavs/1221-135766-0002.flac']
2021-08-20 11:13:10,798 INFO [pretrained.py:271] Decoding started
2021-08-20 11:13:11,085 INFO [pretrained.py:305] Use HLG decoding + LM rescoring
2021-08-20 11:13:11,736 INFO [pretrained.py:339]
./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac:
AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
./tmp/conformer_ctc/test_wavs/1221-135766-0001.flac:
GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONOURED
BOSOM TO CONNECT HER PARENT FOR EVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN
./tmp/conformer_ctc/test_wavs/1221-135766-0002.flac:
YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
2021-08-20 11:13:11,737 INFO [pretrained.py:341] Decoding Done
```
### (4) Use HLG decoding + LM rescoring + attention decoder rescoring
```bash
./conformer_ctc/pretrained.py \
--checkpoint ./tmp/conformer_ctc/exp/pretraind.pt \
--words-file ./tmp/conformer_ctc/data/lang_bpe/words.txt \
--HLG ./tmp/conformer_ctc/data/lang_bpe/HLG.pt \
--method attention-decoder \
--G ./tmp/conformer_ctc/data/lm/G_4_gram.pt \
--ngram-lm-scale 1.3 \
--attention-decoder-scale 1.2 \
--lattice-score-scale 0.5 \
--num-paths 100 \
--sos-id 1 \
--eos-id 1 \
./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac \
./tmp/conformer_ctc/test_wavs/1221-135766-0001.flac \
./tmp/conformer_ctc/test_wavs/1221-135766-0002.flac
```
The output is:
```
2021-08-20 11:19:11,397 INFO [pretrained.py:217] device: cuda:0
2021-08-20 11:19:11,397 INFO [pretrained.py:219] Creating model
2021-08-20 11:19:17,354 INFO [pretrained.py:238] Loading HLG from ./tmp/conformer_ctc/data/lang_bpe/HLG.pt
2021-08-20 11:19:24,615 INFO [pretrained.py:246] Loading G from ./tmp/conformer_ctc/data/lm/G_4_gram.pt
2021-08-20 11:20:04,576 INFO [pretrained.py:255] Constructing Fbank computer
2021-08-20 11:20:04,584 INFO [pretrained.py:265] Reading sound files: ['./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/conformer_ctc/test_wavs/1221-135766-0002.flac']
2021-08-20 11:20:04,595 INFO [pretrained.py:271] Decoding started
2021-08-20 11:20:04,854 INFO [pretrained.py:313] Use HLG + LM rescoring + attention decoder rescoring
2021-08-20 11:20:05,805 INFO [pretrained.py:339]
./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac:
AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
./tmp/conformer_ctc/test_wavs/1221-135766-0001.flac:
GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONOURED
BOSOM TO CONNECT HER PARENT FOR EVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN
./tmp/conformer_ctc/test_wavs/1221-135766-0002.flac:
YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
2021-08-20 11:20:05,805 INFO [pretrained.py:341] Decoding Done
```
**NOTE**: We provide a colab notebook for demonstration.
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1huyupXAcHsUrKaWfI83iMEJ6J0Nh0213?usp=sharing)
Due to limited memory provided by Colab, you have to upgrade to Colab Pro to
run `HLG decoding + LM rescoring` and `HLG decoding + LM rescoring + attention decoder rescoring`.
Otherwise, you can only run `HLG decoding` with Colab.
Please visit
<https://icefall.readthedocs.io/en/latest/recipes/librispeech/conformer_ctc.html>
for how to run this recipe.

View File

@ -1,7 +1,20 @@
#!/usr/bin/env python3
# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
# Apache 2.0
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import warnings
@ -396,7 +409,7 @@ class RelPositionalEncoding(torch.nn.Module):
:,
self.pe.size(1) // 2
- x.size(1)
+ 1 : self.pe.size(1) // 2
+ 1 : self.pe.size(1) // 2 # noqa E203
+ x.size(1),
]
return self.dropout(x), self.dropout(pos_emb)

View File

@ -1,8 +1,20 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# (still working in progress)
import argparse
import logging
@ -45,28 +57,63 @@ def get_parser():
parser.add_argument(
"--epoch",
type=int,
default=9,
default=34,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=1,
default=20,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--method",
type=str,
default="attention-decoder",
help="""Decoding method.
Supported values are:
- (1) 1best. Extract the best path from the decoding lattice as the
decoding result.
- (2) nbest. Extract n paths from the decoding lattice; the path
with the highest score is the decoding result.
- (3) nbest-rescoring. Extract n paths from the decoding lattice,
rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
the highest score is the decoding result.
- (4) whole-lattice-rescoring. Rescore the decoding lattice with an
n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
is the decoding result.
- (5) attention-decoder. Extract n paths from the LM rescored
lattice, the path with the highest score is the decoding result.
- (6) nbest-oracle. Its WER is the lower bound of any n-best
rescoring method can achieve. Useful for debugging n-best
rescoring method.
""",
)
parser.add_argument(
"--num-paths",
type=int,
default=100,
help="""Number of paths for n-best based decoding method.
Used only when "method" is one of the following values:
nbest, nbest-rescoring, attention-decoder, and nbest-oracle
""",
)
parser.add_argument(
"--lattice-score-scale",
type=float,
default=1.0,
help="The scale to be applied to `lattice.scores`."
"It's needed if you use any kinds of n-best based rescoring. "
"Currently, it is used when the decoding method is: nbest, "
"nbest-rescoring, attention-decoder, and nbest-oracle. "
"A smaller value results in more unique paths.",
help="""The scale to be applied to `lattice.scores`.
It's needed if you use any kinds of n-best based rescoring.
Used only when "method" is one of the following values:
nbest, nbest-rescoring, attention-decoder, and nbest-oracle
A smaller value results in more unique paths.
""",
)
return parser
@ -92,21 +139,6 @@ def get_params() -> AttributeDict:
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
# Possible values for method:
# - 1best
# - nbest
# - nbest-rescoring
# - whole-lattice-rescoring
# - attention-decoder
# - nbest-oracle
# "method": "nbest",
# "method": "nbest-rescoring",
# "method": "whole-lattice-rescoring",
"method": "attention-decoder",
# "method": "nbest-oracle",
# num_paths is used when method is "nbest", "nbest-rescoring",
# attention-decoder, and nbest-oracle
"num_paths": 100,
}
)
return params
@ -117,7 +149,7 @@ def decode_one_batch(
model: nn.Module,
HLG: k2.Fsa,
batch: dict,
lexicon: Lexicon,
word_table: k2.SymbolTable,
sos_id: int,
eos_id: int,
G: Optional[k2.Fsa] = None,
@ -151,8 +183,8 @@ def decode_one_batch(
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
lexicon:
It contains word symbol table.
word_table:
The word symbol table.
sos_id:
The token ID of the SOS.
eos_id:
@ -205,7 +237,7 @@ def decode_one_batch(
lattice=lattice,
num_paths=params.num_paths,
ref_texts=supervisions["text"],
lexicon=lexicon,
word_table=word_table,
scale=params.lattice_score_scale,
)
@ -225,7 +257,7 @@ def decode_one_batch(
key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}" # noqa
hyps = get_texts(best_path)
hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
hyps = [[word_table[i] for i in ids] for ids in hyps]
return {key: hyps}
assert params.method in [
@ -271,7 +303,7 @@ def decode_one_batch(
ans = dict()
for lm_scale_str, best_path in best_path_dict.items():
hyps = get_texts(best_path)
hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
hyps = [[word_table[i] for i in ids] for ids in hyps]
ans[lm_scale_str] = hyps
return ans
@ -281,7 +313,7 @@ def decode_dataset(
params: AttributeDict,
model: nn.Module,
HLG: k2.Fsa,
lexicon: Lexicon,
word_table: k2.SymbolTable,
sos_id: int,
eos_id: int,
G: Optional[k2.Fsa] = None,
@ -297,8 +329,8 @@ def decode_dataset(
The neural model.
HLG:
The decoding graph.
lexicon:
It contains word symbol table.
word_table:
It is the word symbol table.
sos_id:
The token ID for SOS.
eos_id:
@ -332,7 +364,7 @@ def decode_dataset(
model=model,
HLG=HLG,
batch=batch,
lexicon=lexicon,
word_table=word_table,
G=G,
sos_id=sos_id,
eos_id=eos_id,
@ -528,7 +560,7 @@ def main():
params=params,
model=model,
HLG=HLG,
lexicon=lexicon,
word_table=lexicon.word_table,
G=G,
sos_id=sos_id,
eos_id=eos_id,

View File

@ -1,4 +1,20 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
@ -59,7 +75,7 @@ def get_parser():
decoding lattice and then use 1best to decode the
rescored lattice.
We call it HLG decoding + n-gram LM rescoring.
(3) attention-decoder - Extract n paths from he rescored
(3) attention-decoder - Extract n paths from the rescored
lattice and use the transformer attention decoder for
rescoring.
We call it HLG decoding + n-gram LM rescoring + attention
@ -245,9 +261,9 @@ def main():
if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
logging.info(f"Loading G from {params.G}")
G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
G = G.to(device)
# Add epsilon self-loops to G as we will compose
# it with the whole lattice later
G = G.to(device)
G = k2.add_epsilon_self_loops(G)
G = k2.arc_sort(G)
G.lm_scores = G.scores.clone()
@ -268,7 +284,7 @@ def main():
)
waves = [w.to(device) for w in waves]
logging.info(f"Decoding started")
logging.info("Decoding started")
features = fbank(waves)
features = pad_sequence(
@ -338,7 +354,7 @@ def main():
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info(f"Decoding Done")
logging.info("Decoding Done")
if __name__ == "__main__":

View File

@ -1,3 +1,20 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn

View File

@ -1,4 +1,20 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from subsampling import Conv2dSubsampling
from subsampling import VggSubsampling

View File

@ -1,4 +1,20 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from transformer import (

View File

@ -1,6 +1,20 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This is just at the very beginning ...
import argparse
import logging
@ -60,6 +74,23 @@ def get_parser():
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--num-epochs",
type=int,
default=35,
help="Number of epochs to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=0,
help="""Resume training from from this epoch.
If it is positive, it will load checkpoint from
conformer_ctc/exp/epoch-{start_epoch-1}.pt
""",
)
return parser
@ -89,11 +120,6 @@ def get_params() -> AttributeDict:
- subsampling_factor: The subsampling factor for the model.
- start_epoch: If it is not zero, load checkpoint `start_epoch-1`
and continue training from that checkpoint.
- num_epochs: Number of epochs to train.
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
@ -129,8 +155,6 @@ def get_params() -> AttributeDict:
"feature_dim": 80,
"weight_decay": 1e-6,
"subsampling_factor": 4,
"start_epoch": 0,
"num_epochs": 20,
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,

View File

@ -1,5 +1,19 @@
# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
# Apache 2.0
# Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Dict, List, Optional, Tuple
@ -779,7 +793,8 @@ class Noam(object):
class LabelSmoothingLoss(nn.Module):
"""
Label-smoothing loss. KL-divergence between q_{smoothed ground truth prob.}(w)
Label-smoothing loss. KL-divergence between
q_{smoothed ground truth prob.}(w)
and p_{prob. computed by model}(w) is minimized.
Modified from
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py # noqa
@ -864,7 +879,8 @@ def encoder_padding_mask(
frames, before subsampling)
Returns:
Tensor: Mask tensor of dimension (batch_size, input_length), True denote the masked indices.
Tensor: Mask tensor of dimension (batch_size, input_length),
True denote the masked indices.
"""
if supervisions is None:
return None

View File

@ -1,4 +1,20 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input lang_dir and generates HLG from

View File

@ -1,8 +1,24 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the LibriSpeech dataset.
Its looks for manifests in the directory data/manifests.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
@ -17,8 +33,9 @@ from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
# Torch's multithreaded behavior needs to be disabled or it wastes a lot of CPU and
# slow things down. Do this outside of main() in case it needs to take effect
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
@ -53,7 +70,8 @@ def compute_fbank_librispeech():
continue
logging.info(f"Processing {partition}")
cut_set = CutSet.from_manifests(
recordings=m["recordings"], supervisions=m["supervisions"],
recordings=m["recordings"],
supervisions=m["supervisions"],
)
if "train" in partition:
cut_set = (

View File

@ -1,8 +1,24 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the musan dataset.
Its looks for manifests in the directory data/manifests.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
@ -17,8 +33,9 @@ from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
# Torch's multithreaded behavior needs to be disabled or it wastes a lot of CPU and
# slow things down. Do this outside of main() in case it needs to take effect
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)

View File

@ -1,6 +1,21 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
"""
This file downloads the following LibriSpeech LM files:

View File

@ -1,6 +1,20 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
"""
This script takes as input a lexicon file "data/lang_phone/lexicon.txt"

View File

@ -1,4 +1,20 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)

View File

@ -1,4 +1,20 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)

View File

@ -1,4 +1,20 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# You can install sentencepiece via:
#

View File

@ -0,0 +1,270 @@
# How to use a pre-trained model to transcribe a sound file or multiple sound files
(See the bottom of this document for the link to a colab notebook.)
You need to prepare 4 files:
- a model checkpoint file, e.g., epoch-20.pt
- HLG.pt, the decoding graph
- words.txt, the word symbol table
- a sound file, whose sampling rate has to be 16 kHz.
Supported formats are those supported by `torchaudio.load()`,
e.g., wav and flac.
Also, you need to install `kaldifeat`. Please refer to
<https://github.com/csukuangfj/kaldifeat> for installation.
```bash
./tdnn_lstm_ctc/pretrained.py --help
```
displays the help information.
## HLG decoding
Once you have the above files ready and have `kaldifeat` installed,
you can run:
```bash
./tdnn_lstm_ctc/pretrained.py \
--checkpoint /path/to/your/checkpoint.pt \
--words-file /path/to/words.txt \
--HLG /path/to/HLG.pt \
/path/to/your/sound.wav
```
and you will see the transcribed result.
If you want to transcribe multiple files at the same time, you can use:
```bash
./tdnn_lstm_ctc/pretrained.py \
--checkpoint /path/to/your/checkpoint.pt \
--words-file /path/to/words.txt \
--HLG /path/to/HLG.pt \
/path/to/your/sound1.wav \
/path/to/your/sound2.wav \
/path/to/your/sound3.wav
```
**Note**: This is the fastest decoding method.
## HLG decoding + LM rescoring
`./tdnn_lstm_ctc/pretrained.py` also supports `whole lattice LM rescoring`.
To use whole lattice LM rescoring, you also need the following files:
- G.pt, e.g., `data/lm/G_4_gram.pt` if you have run `./prepare.sh`
The command to run decoding with LM rescoring is:
```bash
./tdnn_lstm_ctc/pretrained.py \
--checkpoint /path/to/your/checkpoint.pt \
--words-file /path/to/words.txt \
--HLG /path/to/HLG.pt \
--method whole-lattice-rescoring \
--G data/lm/G_4_gram.pt \
--ngram-lm-scale 0.8 \
/path/to/your/sound1.wav \
/path/to/your/sound2.wav \
/path/to/your/sound3.wav
```
# Decoding with a pre-trained model in action
We have uploaded a pre-trained model to <https://huggingface.co/pkufool/icefall_asr_librispeech_tdnn-lstm_ctc>
The following shows the steps about the usage of the provided pre-trained model.
### (1) Download the pre-trained model
```bash
sudo apt-get install git-lfs
cd /path/to/icefall/egs/librispeech/ASR
git lfs install
mkdir tmp
cd tmp
git clone https://huggingface.co/pkufool/icefall_asr_librispeech_tdnn-lstm_ctc
```
**CAUTION**: You have to install `git-lfs` to download the pre-trained model.
You will find the following files:
```
tmp/
`-- icefall_asr_librispeech_tdnn-lstm_ctc
|-- README.md
|-- data
| |-- lang_phone
| | |-- HLG.pt
| | |-- tokens.txt
| | `-- words.txt
| `-- lm
| `-- G_4_gram.pt
|-- exp
| `-- pretrained.pt
`-- test_wavs
|-- 1089-134686-0001.flac
|-- 1221-135766-0001.flac
|-- 1221-135766-0002.flac
`-- trans.txt
6 directories, 10 files
```
**File descriptions**:
- `data/lang_phone/HLG.pt`
It is the decoding graph.
- `data/lang_phone/tokens.txt`
It contains tokens and their IDs.
- `data/lang_phone/words.txt`
It contains words and their IDs.
- `data/lm/G_4_gram.pt`
It is a 4-gram LM, useful for LM rescoring.
- `exp/pretrained.pt`
It contains pre-trained model parameters, obtained by averaging
checkpoints from `epoch-14.pt` to `epoch-19.pt`.
Note: We have removed optimizer `state_dict` to reduce file size.
- `test_waves/*.flac`
It contains some test sound files from LibriSpeech `test-clean` dataset.
- `test_waves/trans.txt`
It contains the reference transcripts for the sound files in `test_waves/`.
The information of the test sound files is listed below:
```
$ soxi tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/*.flac
Input File : 'tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac'
Channels : 1
Sample Rate : 16000
Precision : 16-bit
Duration : 00:00:06.62 = 106000 samples ~ 496.875 CDDA sectors
File Size : 116k
Bit Rate : 140k
Sample Encoding: 16-bit FLAC
Input File : 'tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac'
Channels : 1
Sample Rate : 16000
Precision : 16-bit
Duration : 00:00:16.71 = 267440 samples ~ 1253.62 CDDA sectors
File Size : 343k
Bit Rate : 164k
Sample Encoding: 16-bit FLAC
Input File : 'tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac'
Channels : 1
Sample Rate : 16000
Precision : 16-bit
Duration : 00:00:04.83 = 77200 samples ~ 361.875 CDDA sectors
File Size : 105k
Bit Rate : 174k
Sample Encoding: 16-bit FLAC
Total Duration of 3 files: 00:00:28.16
```
### (2) Use HLG decoding
```bash
cd /path/to/icefall/egs/librispeech/ASR
./tdnn_lstm_ctc/pretrained.py \
--checkpoint ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/exp/pretraind.pt \
--words-file ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/words.txt \
--HLG ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/HLG.pt \
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac \
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac \
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac
```
The output is given below:
```
2021-08-24 16:57:13,315 INFO [pretrained.py:168] device: cuda:0
2021-08-24 16:57:13,315 INFO [pretrained.py:170] Creating model
2021-08-24 16:57:18,331 INFO [pretrained.py:182] Loading HLG from ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/HLG.pt
2021-08-24 16:57:27,581 INFO [pretrained.py:199] Constructing Fbank computer
2021-08-24 16:57:27,584 INFO [pretrained.py:209] Reading sound files: ['./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac']
2021-08-24 16:57:27,599 INFO [pretrained.py:215] Decoding started
2021-08-24 16:57:27,791 INFO [pretrained.py:245] Use HLG decoding
2021-08-24 16:57:28,098 INFO [pretrained.py:266]
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac:
AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac:
GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONORED BOSOM TO CONNECT HER PARENT FOREVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac:
YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
2021-08-24 16:57:28,099 INFO [pretrained.py:268] Decoding Done
```
### (3) Use HLG decoding + LM rescoring
```bash
./tdnn_lstm_ctc/pretrained.py \
--checkpoint ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/exp/pretraind.pt \
--words-file ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/words.txt \
--HLG ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/HLG.pt \
--method whole-lattice-rescoring \
--G ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lm/G_4_gram.pt \
--ngram-lm-scale 0.8 \
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac \
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac \
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac
```
The output is:
```
2021-08-24 16:39:24,725 INFO [pretrained.py:168] device: cuda:0
2021-08-24 16:39:24,725 INFO [pretrained.py:170] Creating model
2021-08-24 16:39:29,403 INFO [pretrained.py:182] Loading HLG from ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/HLG.pt
2021-08-24 16:39:40,631 INFO [pretrained.py:190] Loading G from ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lm/G_4_gram.pt
2021-08-24 16:39:53,098 INFO [pretrained.py:199] Constructing Fbank computer
2021-08-24 16:39:53,107 INFO [pretrained.py:209] Reading sound files: ['./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac']
2021-08-24 16:39:53,121 INFO [pretrained.py:215] Decoding started
2021-08-24 16:39:53,443 INFO [pretrained.py:250] Use HLG decoding + LM rescoring
2021-08-24 16:39:54,010 INFO [pretrained.py:266]
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac:
AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac:
GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONORED BOSOM TO CONNECT HER PARENT FOREVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac:
YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
2021-08-24 16:39:54,010 INFO [pretrained.py:268] Decoding Done
```
**NOTE**: We provide a colab notebook for demonstration.
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1kNmDXNMwREi0rZGAOIAOJo93REBuOTcd?usp=sharing)
Due to limited memory provided by Colab, you have to upgrade to Colab Pro to run `HLG decoding + LM rescoring`.
Otherwise, you can only run `HLG decoding` with Colab.

View File

@ -1,2 +1,4 @@
Will add results later.
Please visit
<https://icefall.readthedocs.io/en/latest/recipes/librispeech/tdnn_lstm_ctc.html>
for how to run this recipe.

View File

@ -1,3 +1,20 @@
# Copyright 2021 Piotr Żelasko
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from functools import lru_cache
@ -23,7 +40,7 @@ from icefall.utils import str2bool
class LibriSpeechAsrDataModule(DataModule):
"""
DataModule for K2 ASR experiments.
DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).

View File

@ -1,4 +1,19 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
@ -27,6 +42,7 @@ from icefall.utils import (
get_texts,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
@ -39,7 +55,7 @@ def get_parser():
parser.add_argument(
"--epoch",
type=int,
default=9,
default=19,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
@ -51,6 +67,16 @@ def get_parser():
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--export",
type=str2bool,
default=False,
help="""When enabled, the averaged model is saved to
tdnn/exp/pretrained.pt. Note: only model.state_dict() is saved.
pretrained.pt contains a dict {"model": model.state_dict()},
which can be loaded by `icefall.checkpoint.load_checkpoint()`.
""",
)
return parser
@ -72,7 +98,7 @@ def get_params() -> AttributeDict:
# - nbest
# - nbest-rescoring
# - whole-lattice-rescoring
"method": "1best",
"method": "whole-lattice-rescoring",
# num_paths is used when method is "nbest" and "nbest-rescoring"
"num_paths": 30,
}
@ -333,7 +359,7 @@ def main():
logging.info(f"device: {device}")
HLG = k2.Fsa.from_dict(
torch.load("data/lang_phone/HLG.pt", map_location="cpu")
torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
)
HLG = HLG.to(device)
assert HLG.requires_grad is False
@ -393,6 +419,12 @@ def main():
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames))
if params.export:
logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
torch.save(
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
)
model.to(device)
model.eval()

View File

@ -1,3 +1,20 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn

View File

@ -0,0 +1,277 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import torch
import torchaudio
from model import TdnnLstm
from torch.nn.utils.rnn import pad_sequence
from icefall.decode import (
get_lattice,
one_best_decoding,
rescore_with_whole_lattice,
)
from icefall.utils import AttributeDict, get_texts
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--words-file",
type=str,
required=True,
help="Path to words.txt",
)
parser.add_argument(
"--HLG", type=str, required=True, help="Path to HLG.pt."
)
parser.add_argument(
"--method",
type=str,
default="1best",
help="""Decoding method.
Possible values are:
(1) 1best - Use the best path as decoding output. Only
the transformer encoder output is used for decoding.
We call it HLG decoding.
(2) whole-lattice-rescoring - Use an LM to rescore the
decoding lattice and then use 1best to decode the
rescored lattice.
We call it HLG decoding + n-gram LM rescoring.
""",
)
parser.add_argument(
"--G",
type=str,
help="""An LM for rescoring.
Used only when method is
whole-lattice-rescoring.
It's usually a 4-gram LM.
""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.8,
help="""
Used only when method is whole-lattice-rescoring.
It specifies the scale for n-gram LM scores.
(Note: You need to tune it on a dataset.)
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"feature_dim": 80,
"subsampling_factor": 3,
"num_classes": 72,
"sample_rate": 16000,
"search_beam": 20,
"output_beam": 5,
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
}
)
return params
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Creating model")
model = TdnnLstm(
num_features=params.feature_dim,
num_classes=params.num_classes,
subsampling_factor=params.subsampling_factor,
)
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"])
model.to(device)
model.eval()
logging.info(f"Loading HLG from {params.HLG}")
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
HLG = HLG.to(device)
if not hasattr(HLG, "lm_scores"):
# For whole-lattice-rescoring and attention-decoder
HLG.lm_scores = HLG.scores.clone()
if params.method == "whole-lattice-rescoring":
logging.info(f"Loading G from {params.G}")
G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
# Add epsilon self-loops to G as we will compose
# it with the whole lattice later
G = G.to(device)
G = k2.add_epsilon_self_loops(G)
G = k2.arc_sort(G)
G.lm_scores = G.scores.clone()
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
features = features.permute(0, 2, 1) # now features is [N, C, T]
with torch.no_grad():
nnet_output = model(features)
# nnet_output is [N, T, C]
batch_size = nnet_output.shape[0]
supervision_segments = torch.tensor(
[[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
dtype=torch.int32,
)
lattice = get_lattice(
nnet_output=nnet_output,
HLG=HLG,
supervision_segments=supervision_segments,
search_beam=params.search_beam,
output_beam=params.output_beam,
min_active_states=params.min_active_states,
max_active_states=params.max_active_states,
subsampling_factor=params.subsampling_factor,
)
if params.method == "1best":
logging.info("Use HLG decoding")
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
elif params.method == "whole-lattice-rescoring":
logging.info("Use HLG decoding + LM rescoring")
best_path_dict = rescore_with_whole_lattice(
lattice=lattice,
G_with_epsilon_loops=G,
lm_scale_list=[params.ngram_lm_scale],
)
best_path = next(iter(best_path_dict.values()))
hyps = get_texts(best_path)
word_sym_table = k2.SymbolTable.from_file(params.words_file)
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -1,4 +1,20 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
@ -59,6 +75,23 @@ def get_parser():
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--num-epochs",
type=int,
default=20,
help="Number of epochs to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=0,
help="""Resume training from from this epoch.
If it is positive, it will load checkpoint from
tdnn_lstm_ctc/exp/epoch-{start_epoch-1}.pt
""",
)
return parser
@ -88,11 +121,6 @@ def get_params() -> AttributeDict:
- subsampling_factor: The subsampling factor for the model.
- start_epoch: If it is not zero, load checkpoint `start_epoch-1`
and continue training from that checkpoint.
- num_epochs: Number of epochs to train.
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
@ -111,6 +139,8 @@ def get_params() -> AttributeDict:
- log_interval: Print training loss if batch_idx % log_interval` is 0
- reset_interval: Reset statistics if batch_idx % reset_interval is 0
- valid_interval: Run validation if batch_idx % valid_interval` is 0
- beam_size: It is used in k2.ctc_loss
@ -127,14 +157,13 @@ def get_params() -> AttributeDict:
"feature_dim": 80,
"weight_decay": 5e-4,
"subsampling_factor": 3,
"start_epoch": 0,
"num_epochs": 10,
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 10,
"reset_interval": 200,
"valid_interval": 1000,
"beam_size": 10,
"reduction": "sum",
@ -382,8 +411,12 @@ def train_one_epoch(
"""
model.train()
tot_loss = 0.0 # sum of losses over all batches
tot_frames = 0.0 # sum of frames over all batches
tot_loss = 0.0 # reset after params.reset_interval of batches
tot_frames = 0.0 # reset after params.reset_interval of batches
params.tot_loss = 0.0
params.tot_frames = 0.0
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
@ -410,6 +443,9 @@ def train_one_epoch(
tot_loss += loss_cpu
tot_avg_loss = tot_loss / tot_frames
params.tot_frames += params.train_frames
params.tot_loss += loss_cpu
if batch_idx % params.log_interval == 0:
logging.info(
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
@ -417,6 +453,22 @@ def train_one_epoch(
f"total avg loss: {tot_avg_loss:.4f}, "
f"batch size: {batch_size}"
)
if tb_writer is not None:
tb_writer.add_scalar(
"train/current_loss",
loss_cpu / params.train_frames,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/tot_avg_loss",
tot_avg_loss,
params.batch_idx_train,
)
if batch_idx > 0 and batch_idx % params.reset_interval == 0:
tot_loss = 0
tot_frames = 0
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
compute_validation_loss(
@ -433,7 +485,7 @@ def train_one_epoch(
f"best valid epoch: {params.best_valid_epoch}"
)
params.train_loss = tot_loss / tot_frames
params.train_loss = params.tot_loss / params.tot_frames
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch

14
egs/yesno/ASR/README.md Normal file
View File

@ -0,0 +1,14 @@
## Yesno recipe
This is the simplest ASR recipe in `icefall`.
It can be run on CPU and takes less than 30 seconds to
get the following WER:
```
[test_set] %WER 0.42% [1 / 240, 0 ins, 1 del, 0 sub ]
```
Please refer to
<https://icefall.readthedocs.io/en/latest/recipes/yesno.html>
for detailed instructions.

View File

@ -0,0 +1,134 @@
#!/usr/bin/env python3
"""
This script takes as input lang_dir and generates HLG from
- H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt
- L, the lexicon, built from lang_dir/L_disambig.pt
Caution: We use a lexicon that contains disambiguation symbols
- G, the LM, built from data/lm/G.fst.txt
The generated HLG is saved in $lang_dir/HLG.pt
"""
import argparse
import logging
from pathlib import Path
import k2
import torch
from icefall.lexicon import Lexicon
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
""",
)
return parser.parse_args()
def compile_HLG(lang_dir: str) -> k2.Fsa:
"""
Args:
lang_dir:
The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
Return:
An FSA representing HLG.
"""
lexicon = Lexicon(lang_dir)
max_token_id = max(lexicon.tokens)
logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
H = k2.ctc_topo(max_token_id)
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
logging.info("Loading G.fst.txt")
with open("data/lm/G.fst.txt") as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
first_token_disambig_id = lexicon.token_table["#0"]
first_word_disambig_id = lexicon.word_table["#0"]
L = k2.arc_sort(L)
G = k2.arc_sort(G)
logging.info("Intersecting L and G")
LG = k2.compose(L, G)
logging.info(f"LG shape: {LG.shape}")
logging.info("Connecting LG")
LG = k2.connect(LG)
logging.info(f"LG shape after k2.connect: {LG.shape}")
logging.info(type(LG.aux_labels))
logging.info("Determinizing LG")
LG = k2.determinize(LG)
logging.info(type(LG.aux_labels))
logging.info("Connecting LG after k2.determinize")
LG = k2.connect(LG)
logging.info("Removing disambiguation symbols on LG")
LG.labels[LG.labels >= first_token_disambig_id] = 0
assert isinstance(LG.aux_labels, k2.RaggedInt)
LG.aux_labels.values()[LG.aux_labels.values() >= first_word_disambig_id] = 0
LG = k2.remove_epsilon(LG)
logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
LG = k2.connect(LG)
LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0)
logging.info("Arc sorting LG")
LG = k2.arc_sort(LG)
logging.info("Composing H and LG")
# CAUTION: The name of the inner_labels is fixed
# to `tokens`. If you want to change it, please
# also change other places in icefall that are using
# it.
HLG = k2.compose(H, LG, inner_labels="tokens")
logging.info("Connecting LG")
HLG = k2.connect(HLG)
logging.info("Arc sorting LG")
HLG = k2.arc_sort(HLG)
logging.info(f"HLG.shape: {HLG.shape}")
return HLG
def main():
args = get_args()
lang_dir = Path(args.lang_dir)
if (lang_dir / "HLG.pt").is_file():
logging.info(f"{lang_dir}/HLG.pt already exists - skipping")
return
logging.info(f"Processing {lang_dir}")
HLG = compile_HLG(lang_dir)
logging.info(f"Saving HLG.pt to {lang_dir}")
torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,81 @@
#!/usr/bin/env python3
"""
This file computes fbank features of the yesno dataset.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import logging
import os
from pathlib import Path
import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
# Torch's multithreaded behavior needs to be disabled or it wastes a
# lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_yesno():
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
# This dataset is rather small, so we use only one job
num_jobs = min(1, os.cpu_count())
num_mel_bins = 23
dataset_parts = (
"train",
"test",
)
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts, output_dir=src_dir
)
assert manifests is not None
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once.
for partition, m in manifests.items():
if (output_dir / f"cuts_{partition}.json.gz").is_file():
logging.info(f"{partition} already exists - skipping.")
continue
logging.info(f"Processing {partition}")
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
)
if "train" in partition:
cut_set = (
cut_set
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/feats_{partition}",
# when an executor is specified, make more partitions
num_jobs=num_jobs if ex is None else 1, # use one job
executor=ex,
storage_type=LilcomHdf5Writer,
)
cut_set.to_json(output_dir / f"cuts_{partition}.json.gz")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_yesno()

View File

@ -0,0 +1,367 @@
#!/usr/bin/env python3
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
"""
This script takes as input a lexicon file "data/lang_phone/lexicon.txt"
consisting of words and tokens (i.e., phones) and does the following:
1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt
2. Generate tokens.txt, the token table mapping a token to a unique integer.
3. Generate words.txt, the word table mapping a word to a unique integer.
4. Generate L.pt, in k2 format. It can be loaded by
d = torch.load("L.pt")
lexicon = k2.Fsa.from_dict(d)
5. Generate L_disambig.pt, in k2 format.
"""
import math
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Tuple
import k2
import torch
from icefall.lexicon import read_lexicon, write_lexicon
Lexicon = List[Tuple[str, List[str]]]
def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
"""Write a symbol to ID mapping to a file.
Note:
No need to implement `read_mapping` as it can be done
through :func:`k2.SymbolTable.from_file`.
Args:
filename:
Filename to save the mapping.
sym2id:
A dict mapping symbols to IDs.
Returns:
Return None.
"""
with open(filename, "w", encoding="utf-8") as f:
for sym, i in sym2id.items():
f.write(f"{sym} {i}\n")
def get_tokens(lexicon: Lexicon) -> List[str]:
"""Get tokens from a lexicon.
Args:
lexicon:
It is the return value of :func:`read_lexicon`.
Returns:
Return a list of unique tokens.
"""
ans = set()
for _, tokens in lexicon:
ans.update(tokens)
sorted_ans = sorted(list(ans))
return sorted_ans
def get_words(lexicon: Lexicon) -> List[str]:
"""Get words from a lexicon.
Args:
lexicon:
It is the return value of :func:`read_lexicon`.
Returns:
Return a list of unique words.
"""
ans = set()
for word, _ in lexicon:
ans.add(word)
sorted_ans = sorted(list(ans))
return sorted_ans
def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]:
"""It adds pseudo-token disambiguation symbols #1, #2 and so on
at the ends of tokens to ensure that all pronunciations are different,
and that none is a prefix of another.
See also add_lex_disambig.pl from kaldi.
Args:
lexicon:
It is returned by :func:`read_lexicon`.
Returns:
Return a tuple with two elements:
- The output lexicon with disambiguation symbols
- The ID of the max disambiguation symbol that appears
in the lexicon
"""
# (1) Work out the count of each token-sequence in the
# lexicon.
count = defaultdict(int)
for _, tokens in lexicon:
count[" ".join(tokens)] += 1
# (2) For each left sub-sequence of each token-sequence, note down
# that it exists (for identifying prefixes of longer strings).
issubseq = defaultdict(int)
for _, tokens in lexicon:
tokens = tokens.copy()
tokens.pop()
while tokens:
issubseq[" ".join(tokens)] = 1
tokens.pop()
# (3) For each entry in the lexicon:
# if the token sequence is unique and is not a
# prefix of another word, no disambig symbol.
# Else output #1, or #2, #3, ... if the same token-seq
# has already been assigned a disambig symbol.
ans = []
# We start with #1 since #0 has its own purpose
first_allowed_disambig = 1
max_disambig = first_allowed_disambig - 1
last_used_disambig_symbol_of = defaultdict(int)
for word, tokens in lexicon:
tokenseq = " ".join(tokens)
assert tokenseq != ""
if issubseq[tokenseq] == 0 and count[tokenseq] == 1:
ans.append((word, tokens))
continue
cur_disambig = last_used_disambig_symbol_of[tokenseq]
if cur_disambig == 0:
cur_disambig = first_allowed_disambig
else:
cur_disambig += 1
if cur_disambig > max_disambig:
max_disambig = cur_disambig
last_used_disambig_symbol_of[tokenseq] = cur_disambig
tokenseq += f" #{cur_disambig}"
ans.append((word, tokenseq.split()))
return ans, max_disambig
def generate_id_map(symbols: List[str]) -> Dict[str, int]:
"""Generate ID maps, i.e., map a symbol to a unique ID.
Args:
symbols:
A list of unique symbols.
Returns:
A dict containing the mapping between symbols and IDs.
"""
return {sym: i for i, sym in enumerate(symbols)}
def add_self_loops(
arcs: List[List[Any]], disambig_token: int, disambig_word: int
) -> List[List[Any]]:
"""Adds self-loops to states of an FST to propagate disambiguation symbols
through it. They are added on each state with non-epsilon output symbols
on at least one arc out of the state.
See also fstaddselfloops.pl from Kaldi. One difference is that
Kaldi uses OpenFst style FSTs and it has multiple final states.
This function uses k2 style FSTs and it does not need to add self-loops
to the final state.
The input label of a self-loop is `disambig_token`, while the output
label is `disambig_word`.
Args:
arcs:
A list-of-list. The sublist contains
`[src_state, dest_state, label, aux_label, score]`
disambig_token:
It is the token ID of the symbol `#0`.
disambig_word:
It is the word ID of the symbol `#0`.
Return:
Return new `arcs` containing self-loops.
"""
states_needs_self_loops = set()
for arc in arcs:
src, dst, ilabel, olabel, score = arc
if olabel != 0:
states_needs_self_loops.add(src)
ans = []
for s in states_needs_self_loops:
ans.append([s, s, disambig_token, disambig_word, 0])
return arcs + ans
def lexicon_to_fst(
lexicon: Lexicon,
token2id: Dict[str, int],
word2id: Dict[str, int],
sil_token: str = "SIL",
sil_prob: float = 0.5,
need_self_loops: bool = False,
) -> k2.Fsa:
"""Convert a lexicon to an FST (in k2 format) with optional silence at
the beginning and end of each word.
Args:
lexicon:
The input lexicon. See also :func:`read_lexicon`
token2id:
A dict mapping tokens to IDs.
word2id:
A dict mapping words to IDs.
sil_token:
The silence token.
sil_prob:
The probability for adding a silence at the beginning and end
of the word.
need_self_loops:
If True, add self-loop to states with non-epsilon output symbols
on at least one arc out of the state. The input label for this
self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
Returns:
Return an instance of `k2.Fsa` representing the given lexicon.
"""
assert sil_prob > 0.0 and sil_prob < 1.0
# CAUTION: we use score, i.e, negative cost.
sil_score = math.log(sil_prob)
no_sil_score = math.log(1.0 - sil_prob)
start_state = 0
loop_state = 1 # words enter and leave from here
sil_state = 2 # words terminate here when followed by silence; this state
# has a silence transition to loop_state.
next_state = 3 # the next un-allocated state, will be incremented as we go.
arcs = []
assert token2id["<eps>"] == 0
assert word2id["<eps>"] == 0
eps = 0
sil_token = token2id[sil_token]
arcs.append([start_state, loop_state, eps, eps, no_sil_score])
arcs.append([start_state, sil_state, eps, eps, sil_score])
arcs.append([sil_state, loop_state, sil_token, eps, 0])
for word, tokens in lexicon:
assert len(tokens) > 0, f"{word} has no pronunciations"
cur_state = loop_state
word = word2id[word]
tokens = [token2id[i] for i in tokens]
for i in range(len(tokens) - 1):
w = word if i == 0 else eps
arcs.append([cur_state, next_state, tokens[i], w, 0])
cur_state = next_state
next_state += 1
# now for the last token of this word
# It has two out-going arcs, one to the loop state,
# the other one to the sil_state.
i = len(tokens) - 1
w = word if i == 0 else eps
arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score])
arcs.append([cur_state, sil_state, tokens[i], w, sil_score])
if need_self_loops:
disambig_token = token2id["#0"]
disambig_word = word2id["#0"]
arcs = add_self_loops(
arcs,
disambig_token=disambig_token,
disambig_word=disambig_word,
)
final_state = next_state
arcs.append([loop_state, final_state, -1, -1, 0])
arcs.append([final_state])
arcs = sorted(arcs, key=lambda arc: arc[0])
arcs = [[str(i) for i in arc] for arc in arcs]
arcs = [" ".join(arc) for arc in arcs]
arcs = "\n".join(arcs)
fsa = k2.Fsa.from_str(arcs, acceptor=False)
return fsa
def main():
out_dir = Path("data/lang_phone")
lexicon_filename = out_dir / "lexicon.txt"
sil_token = "SIL"
sil_prob = 0.5
lexicon = read_lexicon(lexicon_filename)
tokens = get_tokens(lexicon)
words = get_words(lexicon)
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
for i in range(max_disambig + 1):
disambig = f"#{i}"
assert disambig not in tokens
tokens.append(f"#{i}")
assert "<eps>" not in tokens
tokens = ["<eps>"] + tokens
assert "<eps>" not in words
assert "#0" not in words
assert "<s>" not in words
assert "</s>" not in words
words = ["<eps>"] + words + ["#0", "<s>", "</s>"]
token2id = generate_id_map(tokens)
word2id = generate_id_map(words)
write_mapping(out_dir / "tokens.txt", token2id)
write_mapping(out_dir / "words.txt", word2id)
write_lexicon(out_dir / "lexicon_disambig.txt", lexicon_disambig)
L = lexicon_to_fst(
lexicon,
token2id=token2id,
word2id=word2id,
sil_token=sil_token,
sil_prob=sil_prob,
)
L_disambig = lexicon_to_fst(
lexicon_disambig,
token2id=token2id,
word2id=word2id,
sil_token=sil_token,
sil_prob=sil_prob,
need_self_loops=True,
)
torch.save(L.as_dict(), out_dir / "L.pt")
torch.save(L_disambig.as_dict(), out_dir / "L_disambig.pt")
if False:
# Just for debugging, will remove it
L.labels_sym = k2.SymbolTable.from_file(out_dir / "tokens.txt")
L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt")
L_disambig.labels_sym = L.labels_sym
L_disambig.aux_labels_sym = L.aux_labels_sym
L.draw(out_dir / "L.png", title="L")
L_disambig.draw(out_dir / "L_disambig.png", title="L_disambig")
if __name__ == "__main__":
main()

93
egs/yesno/ASR/prepare.sh Executable file
View File

@ -0,0 +1,93 @@
#!/usr/bin/env bash
set -eou pipefail
stage=-1
stop_stage=100
dl_dir=$PWD/download
lang_dir=data/lang_phone
lm_dir=data/lm
. shared/parse_options.sh || exit 1
mkdir -p $lang_dir
mkdir -p $lm_dir
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "dl_dir: $dl_dir"
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "stage 0: Download data"
mkdir -p $dl_dir
if [ ! -f $dl_dir/waves_yesno/.completed ]; then
lhotse download yesno $dl_dir
fi
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare yesno manifest"
mkdir -p data/manifests
lhotse prepare yesno $dl_dir/waves_yesno data/manifests
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Compute fbank for yesno"
mkdir -p data/fbank
./local/compute_fbank_yesno.py
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Prepare lang"
# NOTE: "<UNK> SIL" is added for implementation convenience
# as the graph compiler code requires that there is a OOV word
# in the lexicon.
(
echo "<SIL> SIL"
echo "YES Y"
echo "NO N"
echo "<UNK> SIL"
) > $lang_dir/lexicon.txt
./local/prepare_lang.py
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Prepare G"
# We use a unigram G
cat <<EOF > $lm_dir/G.arpa
\data\\
ngram 1=4
\1-grams:
-1 NO
-1 YES
-99 <s>
-1 </s>
\end\\
EOF
if [ ! -f $lm_dir/G.fst.txt ]; then
python3 -m kaldilm \
--read-symbol-table="$lang_dir/words.txt" \
--disambig-symbol='#0' \
$lm_dir/G.arpa > $lm_dir/G.fst.txt
fi
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Compile HLG"
if [ ! -f $lang_dir/HLG.pt ]; then
./local/compile_hlg.py --lang-dir $lang_dir
fi
fi

1
egs/yesno/ASR/shared Symbolic link
View File

@ -0,0 +1 @@
../../../icefall/shared/

View File

@ -0,0 +1,8 @@
## How to run this recipe
You can find detailed instructions by visiting
<https://icefall.readthedocs.io/en/latest/recipes/yesno.html>
It describes how to run this recipe and how to use
a pre-trained model with `./pretrained.py`.

View File

@ -0,0 +1,248 @@
# Copyright 2021 Piotr Żelasko
# 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import List
from lhotse import CutSet, Fbank, FbankConfig, load_manifest
from lhotse.dataset import (
BucketingSampler,
CutConcatenate,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SingleCutSampler,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from torch.utils.data import DataLoader
from icefall.dataset.datamodule import DataModule
from icefall.utils import str2bool
class YesNoAsrDataModule(DataModule):
"""
DataModule for k2 ASR experiments.
It assumes there is always one train dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- augmentation,
- on-the-fly feature extraction
"""
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
super().add_arguments(parser)
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--feature-dir",
type=Path,
default=Path("data/fbank"),
help="Path to directory with train/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=30.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=False,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=10,
help="The number of buckets for the BucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
default=False,
help="When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding.",
)
group.add_argument(
"--duration-factor",
type=float,
default=1.0,
help="Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch.",
)
group.add_argument(
"--gap",
type=float,
default=1.0,
help="The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when "
"noise augmentation is used.",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
def train_dataloaders(self) -> DataLoader:
logging.info("About to get train cuts")
cuts_train = self.train_cuts()
logging.info("About to create train dataset")
transforms = []
if self.args.concatenate_cuts:
logging.info(
f"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}."
)
# Cut concatenation should be the first transform in the list,
# so that if we e.g. mix noise in, it will fill the gaps between
# different utterances.
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
# NOTE: the PerturbSpeed transform should be added only if we
# remove it from data prep stage.
# Add on-the-fly speed perturbation; since originally it would
# have increased epoch size by 3, we will apply prob 2/3 and use
# 3x more epochs.
# Speed perturbation probably should come first before
# concatenation, but in principle the transforms order doesn't have
# to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=23))
),
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using BucketingSampler.")
train_sampler = BucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
bucket_method="equal_duration",
drop_last=True,
)
else:
logging.info("Using SingleCutSampler.")
train_sampler = SingleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
)
return train_dl
def test_dataloaders(self) -> DataLoader:
logging.info("About to get test cuts")
cuts_test = self.test_cuts()
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=23)))
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
return_cuts=self.args.return_cuts,
)
sampler = SingleCutSampler(
cuts_test, max_duration=self.args.max_duration
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test, batch_size=None, sampler=sampler, num_workers=1
)
return test_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
cuts_train = load_manifest(self.args.feature_dir / "cuts_train.json.gz")
return cuts_train
@lru_cache()
def test_cuts(self) -> List[CutSet]:
logging.info("About to get test cuts")
cuts_test = load_manifest(self.args.feature_dir / "cuts_test.json.gz")
return cuts_test

321
egs/yesno/ASR/tdnn/decode.py Executable file
View File

@ -0,0 +1,321 @@
#!/usr/bin/env python3
import argparse
import logging
from pathlib import Path
from typing import List, Tuple
import k2
import torch
import torch.nn as nn
from asr_datamodule import YesNoAsrDataModule
from model import Tdnn
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.decode import get_lattice, one_best_decoding
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
get_texts,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=14,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=2,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--export",
type=str2bool,
default=False,
help="""When enabled, the averaged model is saved to
tdnn/exp/pretrained.pt. Note: only model.state_dict() is saved.
pretrained.pt contains a dict {"model": model.state_dict()},
which can be loaded by `icefall.checkpoint.load_checkpoint()`.
""",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"exp_dir": Path("tdnn/exp/"),
"lang_dir": Path("data/lang_phone"),
"lm_dir": Path("data/lm"),
"feature_dim": 23,
"search_beam": 20,
"output_beam": 8,
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
}
)
return params
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
HLG: k2.Fsa,
batch: dict,
word_table: k2.SymbolTable,
) -> List[List[int]]:
"""Decode one batch and return the result in a list-of-list.
Each sub list contains the word IDs for an utterance in the batch.
Args:
params:
It's the return value of :func:`get_params`.
- params.method is "1best", it uses 1best decoding.
- params.method is "nbest", it uses nbest decoding.
model:
The neural model.
HLG:
The decoding graph.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
(https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py)
word_table:
It is the word symbol table.
Returns:
Return the decoding result. `len(ans)` == batch size.
"""
device = HLG.device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is [N, T, C]
nnet_output = model(feature)
# nnet_output is [N, T, C]
batch_size = nnet_output.shape[0]
supervision_segments = torch.tensor(
[[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
dtype=torch.int32,
)
lattice = get_lattice(
nnet_output=nnet_output,
HLG=HLG,
supervision_segments=supervision_segments,
search_beam=params.search_beam,
output_beam=params.output_beam,
min_active_states=params.min_active_states,
max_active_states=params.max_active_states,
)
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps]
return hyps
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
HLG: k2.Fsa,
word_table: k2.SymbolTable,
) -> List[Tuple[List[int], List[int]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
HLG:
The decoding graph.
word_table:
It is word symbol table.
Returns:
Return a tuple contains two elements (ref_text, hyp_text):
The first is the reference transcript, and the second is the
predicted result.
"""
results = []
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
results = []
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
hyps = decode_one_batch(
params=params,
model=model,
HLG=HLG,
batch=batch,
word_table=word_table,
)
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
results.extend(this_batch)
num_cuts += len(batch["supervisions"]["text"])
if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
def save_results(
exp_dir: Path,
test_set_name: str,
results: List[Tuple[List[int], List[int]]],
) -> None:
"""Save results to `exp_dir`.
Args:
exp_dir:
The output directory. This function create the following files inside
this directory:
- recogs-{test_set_name}.text
It contains the reference and hypothesis results, like below::
ref=['NO', 'NO', 'NO', 'YES', 'NO', 'NO', 'NO', 'YES']
hyp=['NO', 'NO', 'NO', 'YES', 'NO', 'NO', 'NO', 'YES']
ref=['NO', 'NO', 'YES', 'NO', 'YES', 'NO', 'NO', 'YES']
hyp=['NO', 'NO', 'YES', 'NO', 'YES', 'NO', 'NO', 'YES']
- errs-{test_set_name}.txt
It contains the detailed WER.
test_set_name:
The name of the test set, which will be part of the result filename.
results:
A list of tuples, each of which contains (ref_words, hyp_words).
Returns:
Return None.
"""
recog_path = exp_dir / f"recogs-{test_set_name}.txt"
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = exp_dir / f"errs-{test_set_name}.txt"
with open(errs_filename, "w") as f:
write_error_stats(f, f"{test_set_name}", results)
logging.info("Wrote detailed error stats to {}".format(errs_filename))
@torch.no_grad()
def main():
parser = get_parser()
YesNoAsrDataModule.add_arguments(parser)
args = parser.parse_args()
params = get_params()
params.update(vars(args))
setup_logger(f"{params.exp_dir}/log/log-decode")
logging.info("Decoding started")
logging.info(params)
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
HLG = k2.Fsa.from_dict(
torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
)
HLG = HLG.to(device)
assert HLG.requires_grad is False
model = Tdnn(
num_features=params.feature_dim,
num_classes=max_token_id + 1, # +1 for the blank symbol
)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames))
if params.export:
logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
torch.save(
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
)
model.to(device)
model.eval()
yes_no = YesNoAsrDataModule(args)
test_dl = yes_no.test_dataloaders()
results = decode_dataset(
dl=test_dl,
params=params,
model=model,
HLG=HLG,
word_table=lexicon.word_table,
)
save_results(
exp_dir=params.exp_dir, test_set_name="test_set", results=results
)
logging.info("Done!")
if __name__ == "__main__":
main()

81
egs/yesno/ASR/tdnn/model.py Executable file
View File

@ -0,0 +1,81 @@
#!/usr/bin/env python3
# Copyright (c) 2021 Xiaomi Corp. (author: Fangjun Kuang)
import torch
import torch.nn as nn
class Tdnn(nn.Module):
def __init__(self, num_features: int, num_classes: int):
"""
Args:
num_features:
Model input dimension.
num_classes:
Model output dimension
"""
super().__init__()
self.tdnn = nn.Sequential(
nn.Conv1d(
in_channels=num_features,
out_channels=32,
kernel_size=3,
),
nn.ReLU(inplace=True),
nn.BatchNorm1d(num_features=32, affine=False),
nn.Conv1d(
in_channels=32,
out_channels=32,
kernel_size=5,
dilation=2,
),
nn.ReLU(inplace=True),
nn.BatchNorm1d(num_features=32, affine=False),
nn.Conv1d(
in_channels=32,
out_channels=32,
kernel_size=5,
dilation=4,
),
nn.ReLU(inplace=True),
nn.BatchNorm1d(num_features=32, affine=False),
)
self.output_linear = nn.Linear(in_features=32, out_features=num_classes)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x:
The input tensor with shape [N, T, C]
Returns:
The output tensor has shape [N, T, C]
"""
x = x.permute(0, 2, 1) # [N, T, C] -> [N, C, T]
x = self.tdnn(x)
x = x.permute(0, 2, 1) # [N, C, T] -> [N, T, C]
x = self.output_linear(x)
x = nn.functional.log_softmax(x, dim=-1)
return x
def test_tdnn():
num_features = 23
num_classes = 4
model = Tdnn(num_features=num_features, num_classes=num_classes)
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
N = 2
T = 100
C = num_features
x = torch.randn(N, T, C)
y = model(x)
print(x.shape)
print(y.shape)
if __name__ == "__main__":
test_tdnn()

209
egs/yesno/ASR/tdnn/pretrained.py Executable file
View File

@ -0,0 +1,209 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import torch
import torchaudio
from model import Tdnn
from torch.nn.utils.rnn import pad_sequence
from icefall.decode import get_lattice, one_best_decoding
from icefall.utils import AttributeDict, get_texts
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--words-file",
type=str,
required=True,
help="Path to words.txt",
)
parser.add_argument(
"--HLG", type=str, required=True, help="Path to HLG.pt."
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"feature_dim": 23,
"num_classes": 4, # [<blk>, N, SIL, Y]
"sample_rate": 8000,
"search_beam": 20,
"output_beam": 8,
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
}
)
return params
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Creating model")
model = Tdnn(
num_features=params.feature_dim,
num_classes=params.num_classes,
)
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"])
model.to(device)
model.eval()
logging.info(f"Loading HLG from {params.HLG}")
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
HLG = HLG.to(device)
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
# Note: We don't use key padding mask for attention during decoding
with torch.no_grad():
nnet_output = model(features)
batch_size = nnet_output.shape[0]
supervision_segments = torch.tensor(
[[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
dtype=torch.int32,
)
lattice = get_lattice(
nnet_output=nnet_output,
HLG=HLG,
supervision_segments=supervision_segments,
search_beam=params.search_beam,
output_beam=params.output_beam,
min_active_states=params.min_active_states,
max_active_states=params.max_active_states,
)
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
hyps = get_texts(best_path)
word_sym_table = k2.SymbolTable.from_file(params.words_file)
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

584
egs/yesno/ASR/tdnn/train.py Executable file
View File

@ -0,0 +1,584 @@
#!/usr/bin/env python3
import argparse
import logging
from pathlib import Path
from shutil import copyfile
from typing import Optional
import k2
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from asr_datamodule import YesNoAsrDataModule
from lhotse.utils import fix_random_seed
from model import Tdnn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, setup_logger, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--world-size",
type=int,
default=1,
help="Number of GPUs for DDP training.",
)
parser.add_argument(
"--master-port",
type=int,
default=12354,
help="Master port to use for DDP training.",
)
parser.add_argument(
"--tensorboard",
type=str2bool,
default=True,
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--num-epochs",
type=int,
default=15,
help="Number of epochs to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=0,
help="""Resume training from from this epoch.
If it is positive, it will load checkpoint from
tdnn/exp/epoch-{start_epoch-1}.pt
""",
)
return parser
def get_params() -> AttributeDict:
"""Return a dict containing training parameters.
All training related parameters that are not passed from the commandline
is saved in the variable `params`.
Commandline options are merged into `params` after they are parsed, so
you can also access them via `params`.
Explanation of options saved in `params`:
- exp_dir: It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
- lang_dir: It contains language related input files such as
"lexicon.txt"
- lr: It specifies the initial learning rate
- feature_dim: The model input dim. It has to match the one used
in computing features.
- weight_decay: The weight_decay for the optimizer.
- subsampling_factor: The subsampling factor for the model.
- start_epoch: If it is not zero, load checkpoint `start_epoch-1`
and continue training from that checkpoint.
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
- best_valid_loss: Best validation loss so far. It is used to select
the model that has the lowest validation loss. It is
updated during the training.
- best_train_epoch: It is the epoch that has the best training loss.
- best_valid_epoch: It is the epoch that has the best validation loss.
- batch_idx_train: Used to writing statistics to tensorboard. It
contains number of batches trained so far across
epochs.
- log_interval: Print training loss if batch_idx % log_interval` is 0
- valid_interval: Run validation if batch_idx % valid_interval` is 0
- beam_size: It is used in k2.ctc_loss
- reduction: It is used in k2.ctc_loss
- use_double_scores: It is used in k2.ctc_loss
"""
params = AttributeDict(
{
"exp_dir": Path("tdnn/exp"),
"lang_dir": Path("data/lang_phone"),
"lr": 1e-2,
"feature_dim": 23,
"weight_decay": 1e-6,
"start_epoch": 0,
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 10,
"valid_interval": 10,
"beam_size": 10,
"reduction": "sum",
"use_double_scores": True,
}
)
return params
def load_checkpoint_if_available(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> None:
"""Load checkpoint from file.
If params.start_epoch is positive, it will load the checkpoint from
`params.start_epoch - 1`. Otherwise, this function does nothing.
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
and `best_valid_loss` in `params`.
Args:
params:
The return value of :func:`get_params`.
model:
The training model.
optimizer:
The optimizer that we are using.
scheduler:
The learning rate scheduler we are using.
Returns:
Return None.
"""
if params.start_epoch <= 0:
return
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
saved_params = load_checkpoint(
filename,
model=model,
optimizer=optimizer,
scheduler=scheduler,
)
keys = [
"best_train_epoch",
"best_valid_epoch",
"batch_idx_train",
"best_train_loss",
"best_valid_loss",
]
for k in keys:
params[k] = saved_params[k]
return saved_params
def save_checkpoint(
params: AttributeDict,
model: nn.Module,
optimizer: torch.optim.Optimizer,
scheduler: torch.optim.lr_scheduler._LRScheduler,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
Args:
params:
It is returned by :func:`get_params`.
model:
The training model.
"""
if rank != 0:
return
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
save_checkpoint_impl(
filename=filename,
model=model,
params=params,
optimizer=optimizer,
scheduler=scheduler,
rank=rank,
)
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
def compute_loss(
params: AttributeDict,
model: nn.Module,
batch: dict,
graph_compiler: CtcTrainingGraphCompiler,
is_training: bool,
):
"""
Compute CTC loss given the model and its inputs.
Args:
params:
Parameters for training. See :func:`get_params`.
model:
The model for training. It is an instance of Tdnn in our case.
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
graph_compiler:
It is used to build a decoding graph from a ctc topo and training
transcript. The training transcript is contained in the given `batch`,
while the ctc topo is built when this compiler is instantiated.
is_training:
True for training. False for validation. When it is True, this
function enables autograd during computation; when it is False, it
disables autograd.
"""
device = graph_compiler.device
feature = batch["inputs"]
# at entry, feature is [N, T, C]
assert feature.ndim == 3
feature = feature.to(device)
with torch.set_grad_enabled(is_training):
nnet_output = model(feature)
# nnet_output is [N, T, C]
# NOTE: We need `encode_supervisions` to sort sequences with
# different duration in decreasing order, required by
# `k2.intersect_dense` called in `k2.ctc_loss`
supervisions = batch["supervisions"]
texts = supervisions["text"]
batch_size = nnet_output.shape[0]
supervision_segments = torch.tensor(
[[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
dtype=torch.int32,
)
decoding_graph = graph_compiler.compile(texts)
dense_fsa_vec = k2.DenseFsaVec(
nnet_output,
supervision_segments,
)
loss = k2.ctc_loss(
decoding_graph=decoding_graph,
dense_fsa_vec=dense_fsa_vec,
output_beam=params.beam_size,
reduction=params.reduction,
use_double_scores=params.use_double_scores,
)
assert loss.requires_grad == is_training
# train_frames and valid_frames are used for printing.
if is_training:
params.train_frames = supervision_segments[:, 2].sum().item()
else:
params.valid_frames = supervision_segments[:, 2].sum().item()
return loss
def compute_validation_loss(
params: AttributeDict,
model: nn.Module,
graph_compiler: CtcTrainingGraphCompiler,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> None:
"""Run the validation process. The validation loss
is saved in `params.valid_loss`.
"""
model.eval()
tot_loss = 0.0
tot_frames = 0.0
for batch_idx, batch in enumerate(valid_dl):
loss = compute_loss(
params=params,
model=model,
batch=batch,
graph_compiler=graph_compiler,
is_training=False,
)
assert loss.requires_grad is False
loss_cpu = loss.detach().cpu().item()
tot_loss += loss_cpu
tot_frames += params.valid_frames
if world_size > 1:
s = torch.tensor([tot_loss, tot_frames], device=loss.device)
dist.all_reduce(s, op=dist.ReduceOp.SUM)
s = s.cpu().tolist()
tot_loss = s[0]
tot_frames = s[1]
params.valid_loss = tot_loss / tot_frames
if params.valid_loss < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = params.valid_loss
def train_one_epoch(
params: AttributeDict,
model: nn.Module,
optimizer: torch.optim.Optimizer,
graph_compiler: CtcTrainingGraphCompiler,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
) -> None:
"""Train the model for one epoch.
The training loss from the mean of all frames is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model for training.
optimizer:
The optimizer we are using.
graph_compiler:
It is used to convert transcripts to FSAs.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
Number of nodes in DDP training. If it is 1, DDP is disabled.
"""
model.train()
tot_loss = 0.0 # sum of losses over all batches
tot_frames = 0.0 # sum of frames over all batches
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
loss = compute_loss(
params=params,
model=model,
batch=batch,
graph_compiler=graph_compiler,
is_training=True,
)
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
optimizer.zero_grad()
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
loss_cpu = loss.detach().cpu().item()
tot_frames += params.train_frames
tot_loss += loss_cpu
tot_avg_loss = tot_loss / tot_frames
if batch_idx % params.log_interval == 0:
logging.info(
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
f"total avg loss: {tot_avg_loss:.4f}, "
f"batch size: {batch_size}"
)
if tb_writer is not None:
tb_writer.add_scalar(
"train/current_loss",
loss_cpu / params.train_frames,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/tot_avg_loss",
tot_avg_loss,
params.batch_idx_train,
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
compute_validation_loss(
params=params,
model=model,
graph_compiler=graph_compiler,
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
logging.info(
f"Epoch {params.cur_epoch}, valid loss {params.valid_loss:.4f},"
f" best valid loss: {params.best_valid_loss:.4f} "
f"best valid epoch: {params.best_valid_epoch}"
)
if tb_writer is not None:
tb_writer.add_scalar(
"train/valid_loss",
params.valid_loss,
params.batch_idx_train,
)
params.train_loss = tot_loss / tot_frames
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
fix_random_seed(42)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
logging.info(params)
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
lexicon = Lexicon(params.lang_dir)
max_phone_id = max(lexicon.tokens)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
graph_compiler = CtcTrainingGraphCompiler(lexicon=lexicon, device=device)
model = Tdnn(
num_features=params.feature_dim,
num_classes=max_phone_id + 1, # +1 for the blank symbol
)
checkpoints = load_checkpoint_if_available(params=params, model=model)
model.to(device)
if world_size > 1:
model = DDP(model, device_ids=[rank])
optimizer = optim.SGD(
model.parameters(),
lr=params.lr,
weight_decay=params.weight_decay,
)
if checkpoints:
optimizer.load_state_dict(checkpoints["optimizer"])
yes_no = YesNoAsrDataModule(args)
train_dl = yes_no.train_dataloaders()
# There are only 60 waves: 30 files are used for training
# and the remaining 30 files are used for testing.
# We use test data as validation.
valid_dl = yes_no.test_dataloaders()
for epoch in range(params.start_epoch, params.num_epochs):
train_dl.sampler.set_epoch(epoch)
if tb_writer is not None:
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
params.cur_epoch = epoch
train_one_epoch(
params=params,
model=model,
optimizer=optimizer,
graph_compiler=graph_compiler,
train_dl=train_dl,
valid_dl=valid_dl,
tb_writer=tb_writer,
world_size=world_size,
)
save_checkpoint(
params=params,
model=model,
optimizer=optimizer,
scheduler=None,
rank=rank,
)
logging.info("Done!")
if world_size > 1:
torch.distributed.barrier()
cleanup_dist()
def main():
parser = get_parser()
YesNoAsrDataModule.add_arguments(parser)
args = parser.parse_args()
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)
if __name__ == "__main__":
main()

View File

@ -1,3 +1,20 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import List, Union
@ -57,7 +74,9 @@ class BpeCtcTrainingGraphCompiler(object):
return self.sp.encode(texts, out_type=int)
def compile(
self, piece_ids: List[List[int]], modified: bool = False,
self,
piece_ids: List[List[int]],
modified: bool = False,
) -> k2.Fsa:
"""Build a ctc graph from a list-of-list piece IDs.

View File

@ -1,3 +1,20 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

View File

@ -1,3 +1,20 @@
# Copyright 2021 Piotr Żelasko
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from typing import List, Union

View File

@ -1,3 +1,19 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Dict, List, Optional, Tuple, Union
@ -6,8 +22,6 @@ import kaldialign
import torch
import torch.nn as nn
from icefall.lexicon import Lexicon
def _get_random_paths(
lattice: k2.Fsa,
@ -607,7 +621,7 @@ def nbest_oracle(
lattice: k2.Fsa,
num_paths: int,
ref_texts: List[str],
lexicon: Lexicon,
word_table: k2.SymbolTable,
scale: float = 1.0,
) -> Dict[str, List[List[int]]]:
"""Select the best hypothesis given a lattice and a reference transcript.
@ -628,8 +642,8 @@ def nbest_oracle(
ref_texts:
A list of reference transcript. Each entry contains space(s)
separated words
lexicon:
It is used to convert word IDs to word symbols.
word_table:
It is the word symbol table.
scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
@ -664,7 +678,7 @@ def nbest_oracle(
best_hyp_words = None
min_error = float("inf")
for hyp_words in hyps:
hyp_words = [lexicon.word_table[i] for i in hyp_words]
hyp_words = [word_table[i] for i in hyp_words]
this_error = kaldialign.edit_distance(ref_words, hyp_words)["total"]
if this_error < min_error:
min_error = this_error

View File

@ -1,3 +1,20 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch

View File

@ -1,3 +1,20 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import k2
@ -8,7 +25,10 @@ from icefall.lexicon import Lexicon
class CtcTrainingGraphCompiler(object):
def __init__(
self, lexicon: Lexicon, device: torch.device, oov: str = "<UNK>",
self,
lexicon: Lexicon,
device: torch.device,
oov: str = "<UNK>",
):
"""
Args:

View File

@ -1,3 +1,20 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
import sys

View File

@ -1,3 +1,20 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os

View File

@ -1,6 +1,20 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.lexicon import BpeLexicon
@ -15,7 +29,7 @@ def test():
compiler = BpeCtcTrainingGraphCompiler(lang_dir)
ids = compiler.texts_to_ids(["HELLO", "WORLD ZZZ"])
fsa = compiler.compile(ids)
compiler.compile(ids)
lexicon = BpeLexicon(lang_dir)
ids0 = lexicon.words_to_piece_ids(["HELLO"])

View File

@ -1,4 +1,20 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch

View File

@ -1,6 +1,20 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
import re

View File

@ -1,4 +1,20 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path

View File

@ -1,4 +1,21 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import k2
import pytest
import torch