Merge branch 'master' into docs

This commit is contained in:
pkufool 2021-09-12 16:33:00 +08:00
commit be2795cff2
41 changed files with 534 additions and 183 deletions

View File

@ -5,7 +5,6 @@ max-line-length = 80
per-file-ignores = per-file-ignores =
# line too long # line too long
egs/librispeech/ASR/conformer_ctc/conformer.py: E501, egs/librispeech/ASR/conformer_ctc/conformer.py: E501,
egs/librispeech/ASR/conformer_ctc/decode.py: E501,
exclude = exclude =
.git, .git,

View File

@ -56,7 +56,7 @@ jobs:
run: | run: |
python3 -m pip install --upgrade pip black flake8 python3 -m pip install --upgrade pip black flake8
python3 -m pip install -U pip python3 -m pip install -U pip
python3 -m pip install k2==1.4.dev20210822+cpu.torch1.7.1 -f https://k2-fsa.org/nightly/ python3 -m pip install k2==1.7.dev20210908+cpu.torch1.7.1 -f https://k2-fsa.org/nightly/
python3 -m pip install torchaudio==0.7.2 python3 -m pip install torchaudio==0.7.2
python3 -m pip install git+https://github.com/lhotse-speech/lhotse python3 -m pip install git+https://github.com/lhotse-speech/lhotse

View File

@ -45,7 +45,7 @@ jobs:
- name: Install Python dependencies - name: Install Python dependencies
run: | run: |
python3 -m pip install --upgrade pip black flake8 python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2
- name: Run flake8 - name: Run flake8
shell: bash shell: bash

View File

@ -32,7 +32,8 @@ jobs:
os: [ubuntu-18.04, macos-10.15] os: [ubuntu-18.04, macos-10.15]
python-version: [3.6, 3.7, 3.8, 3.9] python-version: [3.6, 3.7, 3.8, 3.9]
torch: ["1.8.1"] torch: ["1.8.1"]
k2-version: ["1.4.dev20210822"] k2-version: ["1.7.dev20210908"]
fail-fast: false fail-fast: false
steps: steps:

2
.gitignore vendored
View File

@ -4,4 +4,4 @@ path.sh
exp exp
exp*/ exp*/
*.pt *.pt
download/ download

View File

@ -1 +1,2 @@
sphinx_rtd_theme sphinx_rtd_theme
sphinx

View File

@ -16,7 +16,6 @@
import sphinx_rtd_theme import sphinx_rtd_theme
# -- Project information ----------------------------------------------------- # -- Project information -----------------------------------------------------
project = "icefall" project = "icefall"

View File

@ -0,0 +1,67 @@
.. _follow the code style:
Follow the code style
=====================
We use the following tools to keep the code style as consistent as possible:
- `black <https://github.com/psf/black>`_, to format the code
- `flake8 <https://github.com/PyCQA/flake8>`_, to check the style and quality of the code
- `isort <https://github.com/PyCQA/isort>`_, to sort ``imports``
The following versions of the above tools are used:
- ``black == 21.6b0``
- ``flake8 == 3.9.2``
- ``isort == 5.9.2``
After running the following commands:
.. code-block::
$ git clone https://github.com/k2-fsa/icefall
$ cd icefall
$ pip install pre-commit
$ pre-commit install
it will run the following checks whenever you run ``git commit``, **automatically**:
.. figure:: images/pre-commit-check.png
:width: 600
:align: center
pre-commit hooks invoked by ``git commit`` (Failed).
If any of the above checks fails, your ``git commit`` will not succeed.
Please fix any issues reported by the check tools.
.. HINT::
Some of the check tools, i.e., ``black`` and ``isort``, will modify
the files to be committed **in-place**. So please run ``git status``
after a failure to see which files have been modified by the tools
before you make any further changes.
After fixing all the failures, run ``git commit`` again and
it should succeed this time:
.. figure:: images/pre-commit-check-success.png
:width: 600
:align: center
pre-commit hooks invoked by ``git commit`` (Succeeded).
If you want to check the style of your code before ``git commit``, you
can do the following:
.. code-block:: bash
$ cd icefall
$ pip install black==21.6b0 flake8==3.9.2 isort==5.9.2
$ black --check your_changed_file.py
$ black your_changed_file.py # modify it in-place
$
$ flake8 your_changed_file.py
$
$ isort --check your_changed_file.py
$ isort your_changed_file.py # modify it in-place
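The manual checks above could also be wrapped in a small Python helper. The following is a minimal sketch, not part of ``icefall``: the names ``build_check_commands`` and ``run_checks`` are hypothetical, and it assumes that ``black``, ``flake8``, and ``isort`` are already installed and on your ``PATH``.

```python
import subprocess
from typing import List


def build_check_commands(files: List[str]) -> List[List[str]]:
    """Return the check-only commands for the given files.

    None of these commands modifies a file: black and isort are run
    with --check, and flake8 never rewrites files.
    """
    return [
        ["black", "--check", *files],
        ["flake8", *files],
        ["isort", "--check", *files],
    ]


def run_checks(files: List[str]) -> bool:
    """Run every check and return True only if all of them pass.

    Raises FileNotFoundError if one of the tools is not installed.
    """
    all_passed = True
    for cmd in build_check_commands(files):
        # A non-zero exit code means the tool reported a problem.
        if subprocess.run(cmd).returncode != 0:
            all_passed = False
    return all_passed
```

Calling ``run_checks(["your_changed_file.py"])`` runs all three tools even if an earlier one fails, so you see every report in one pass.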

View File

@ -0,0 +1,45 @@
Contributing to Documentation
=============================
We use `sphinx <https://www.sphinx-doc.org/en/master/>`_
for documentation.
Before writing documentation, you have to prepare the environment:
.. code-block:: bash
$ cd docs
$ pip install -r requirements.txt
After setting up the environment, you are ready to write documentation.
Please refer to `reStructuredText Primer <https://www.sphinx-doc.org/en/master/usage/restructuredtext/basics.html>`_
if you are not familiar with ``reStructuredText``.
After writing some documentation, you can build the documentation **locally**
to preview what it will look like when it is published:
.. code-block:: bash
$ cd docs
$ make html
The generated documentation is in ``docs/build/html`` and can be viewed
with the following commands:
.. code-block:: bash
$ cd docs/build/html
$ python3 -m http.server
It will print::
Serving HTTP on 0.0.0.0 port 8000 (http://0.0.0.0:8000/) ...
Open your browser, go to `<http://0.0.0.0:8000/>`_, and you will see
the following:
.. figure:: images/doc-contrib.png
:width: 600
:align: center
View generated documentation locally with ``python3 -m http.server``.
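If you prefer to start the preview server from Python, for example to bind it to the loopback address only, something like the following could work. It is a sketch under assumptions: ``make_doc_server`` is a hypothetical helper, and it requires Python 3.7 or later because it relies on ``ThreadingHTTPServer`` and the ``directory`` argument of ``SimpleHTTPRequestHandler``.

```python
import functools
import http.server


def make_doc_server(
    directory: str, port: int = 8000
) -> http.server.ThreadingHTTPServer:
    """Create a server for `directory` that listens on 127.0.0.1 only.

    Passing port=0 lets the operating system pick a free port.
    """
    handler = functools.partial(
        http.server.SimpleHTTPRequestHandler, directory=directory
    )
    return http.server.ThreadingHTTPServer(("127.0.0.1", port), handler)


# Typical use, run from the repository root:
#
#   server = make_doc_server("docs/build/html")
#   print(f"Serving on http://127.0.0.1:{server.server_address[1]}/")
#   server.serve_forever()
```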

View File

@ -0,0 +1,156 @@
How to create a recipe
======================
.. HINT::
Please read :ref:`follow the code style` to adjust your code style.
.. CAUTION::
``icefall`` is designed to be as Pythonic as possible. Please use
Python in your recipe if possible.
Data Preparation
----------------
We recommend that you prepare your training/test/validation datasets
with `lhotse <https://github.com/lhotse-speech/lhotse>`_.
Please refer to `<https://lhotse.readthedocs.io/en/latest/index.html>`_
for how to create a recipe in ``lhotse``.
.. HINT::
The ``yesno`` recipe in ``lhotse`` is a very good example.
Please refer to `<https://github.com/lhotse-speech/lhotse/pull/380>`_,
which shows how to add a new recipe to ``lhotse``.
Suppose you would like to add a recipe for a dataset named ``foo``.
You can do the following:
.. code-block::
$ cd egs
$ mkdir -p foo/ASR
$ cd foo/ASR
$ touch prepare.sh
$ chmod +x prepare.sh
If your dataset is very simple, please follow
`egs/yesno/ASR/prepare.sh <https://github.com/k2-fsa/icefall/blob/master/egs/yesno/ASR/prepare.sh>`_
to write your own ``prepare.sh``.
Otherwise, please refer to
`egs/librispeech/ASR/prepare.sh <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/prepare.sh>`_
to prepare your data.
Training
--------
Assuming you have a fancy model called ``bar`` for the ``foo`` recipe, you can
organize your files in the following way:
.. code-block::
$ cd egs/foo/ASR
$ mkdir bar
$ cd bar
$ touch README.md model.py train.py decode.py asr_datamodule.py pretrained.py
For instance, the ``yesno`` recipe has a ``tdnn`` model and its directory structure
looks like the following:
.. code-block:: bash
egs/yesno/ASR/tdnn/
|-- README.md
|-- asr_datamodule.py
|-- decode.py
|-- model.py
|-- pretrained.py
`-- train.py
**File description**:
- ``README.md``
It contains information about this recipe, e.g., how to run it, what the WER is, etc.
- ``asr_datamodule.py``
It provides code to create PyTorch dataloaders for the train/test/validation datasets.
- ``decode.py``
It takes as inputs the checkpoints saved during the training stage to decode the test
dataset(s).
- ``model.py``
It contains the definition of your fancy neural network model.
- ``pretrained.py``
We can use this script to do inference with a pre-trained model.
- ``train.py``
It contains training code.
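In keeping with the advice to use Python where possible, the file layout described above could also be created with a short script. This is only a sketch: ``create_model_skeleton`` and ``RECIPE_FILES`` are hypothetical names, and marking the three scripts as executable is an addition of this example (it is what lets them be run as ``./bar/train.py``), not something the ``touch`` command does.

```python
from pathlib import Path
from typing import List

# The files listed in the description above.
RECIPE_FILES = [
    "README.md",
    "asr_datamodule.py",
    "decode.py",
    "model.py",
    "pretrained.py",
    "train.py",
]

# Scripts that are meant to be invoked directly from the shell.
EXECUTABLE_FILES = {"decode.py", "pretrained.py", "train.py"}


def create_model_skeleton(egs_dir: str, dataset: str, model: str) -> List[Path]:
    """Create empty recipe files under <egs_dir>/<dataset>/ASR/<model>/.

    Existing files are left untouched; only missing ones are created.
    """
    model_dir = Path(egs_dir) / dataset / "ASR" / model
    model_dir.mkdir(parents=True, exist_ok=True)

    created = []
    for name in RECIPE_FILES:
        path = model_dir / name
        path.touch(exist_ok=True)
        if name in EXECUTABLE_FILES:
            # Add the execute bits, like `chmod +x`.
            path.chmod(path.stat().st_mode | 0o111)
        created.append(path)
    return created
```

For the example in this page you would call ``create_model_skeleton("egs", "foo", "bar")`` from the repository root.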
.. HINT::
Please take a look at
- `egs/yesno/tdnn <https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR/tdnn>`_
- `egs/librispeech/tdnn_lstm_ctc <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/tdnn_lstm_ctc>`_
- `egs/librispeech/conformer_ctc <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/conformer_ctc>`_
to get a feel for what the resulting files look like.
.. NOTE::
Every model in a recipe is kept as self-contained as possible.
We tolerate duplicate code among different recipes.
The training stage should be invocable by:
.. code-block::
$ cd egs/foo/ASR
$ ./bar/train.py
$ ./bar/train.py --help
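For ``./bar/train.py --help`` to work, the script needs a shebang line, the executable bit, and an argument parser. The following is a minimal sketch of such a parser for the hypothetical ``bar`` model. The option names are borrowed from the LibriSpeech recipes in this repository, but the defaults and help texts are assumptions made for illustration, and ``str2bool`` is defined locally to keep the example self-contained.

```python
#!/usr/bin/env python3
import argparse


def str2bool(value) -> bool:
    """Convert strings such as "True" or "false" to bool for argparse."""
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


def get_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        description="Train the bar model on the foo dataset (illustrative).",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--world-size",
        type=int,
        default=1,
        help="Number of GPUs to use for training.",
    )
    parser.add_argument(
        "--max-duration",
        type=int,
        default=200,
        help="Maximum pooled recordings duration (seconds) in a single batch.",
    )
    parser.add_argument(
        "--bucketing-sampler",
        type=str2bool,
        default=True,
        help="When enabled, batches come from buckets of similar duration.",
    )
    return parser
```

A ``main()`` function would then call ``get_parser().parse_args()`` and pass the result to the training code.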
Decoding
--------
Please refer to
- `<https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/conformer_ctc/decode.py>`_
if your model is transformer/conformer based.
- `<https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py>`_
if your model is TDNN/LSTM based, i.e., there is no attention decoder.
- `<https://github.com/k2-fsa/icefall/blob/master/egs/yesno/ASR/tdnn/decode.py>`_
if there is no LM rescoring.
The decoding stage should be invocable by:
.. code-block::
$ cd egs/foo/ASR
$ ./bar/decode.py
$ ./bar/decode.py --help
Pre-trained model
-----------------
Please demonstrate how to use your model for inference in ``egs/foo/ASR/bar/pretrained.py``.
If possible, please consider creating a Colab notebook to show that.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -0,0 +1,22 @@
Contributing
============
Contributions to ``icefall`` are very welcome.
There are many possible ways to make contributions and
two of them are:
- To write documentation
- To write code
- (1) To follow the code style in the repository
- (2) To write a new recipe
On this page, we describe how to contribute documentation
and code to ``icefall``.
.. toctree::
:maxdepth: 2
doc
code-style
how-to-create-a-recipe

View File

@ -22,3 +22,4 @@ speech recognition recipes using `k2 <https://github.com/k2-fsa/k2>`_.
installation/index installation/index
recipes/index recipes/index
contributing/index

View File

@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" width="80" height="20" role="img" aria-label="k2: &gt;= v1.7"><title>k2: &gt;= v1.7</title><linearGradient id="s" x2="0" y2="100%"><stop offset="0" stop-color="#bbb" stop-opacity=".1"/><stop offset="1" stop-opacity=".1"/></linearGradient><clipPath id="r"><rect width="80" height="20" rx="3" fill="#fff"/></clipPath><g clip-path="url(#r)"><rect width="23" height="20" fill="#555"/><rect x="23" width="57" height="20" fill="blueviolet"/><rect width="80" height="20" fill="url(#s)"/></g><g fill="#fff" text-anchor="middle" font-family="Verdana,Geneva,DejaVu Sans,sans-serif" text-rendering="geometricPrecision" font-size="110"><text aria-hidden="true" x="125" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="130">k2</text><text x="125" y="140" transform="scale(.1)" fill="#fff" textLength="130">k2</text><text aria-hidden="true" x="505" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="470">&gt;= v1.7</text><text x="505" y="140" transform="scale(.1)" fill="#fff" textLength="470">&gt;= v1.7</text></g></svg>


View File

@ -7,6 +7,7 @@ Installation
- |device| - |device|
- |python_versions| - |python_versions|
- |torch_versions| - |torch_versions|
- |k2_versions|
.. |os| image:: ./images/os-Linux_macOS-ff69b4.svg .. |os| image:: ./images/os-Linux_macOS-ff69b4.svg
:alt: Supported operating systems :alt: Supported operating systems
@ -20,7 +21,10 @@ Installation
.. |torch_versions| image:: ./images/torch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg .. |torch_versions| image:: ./images/torch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg
:alt: Supported PyTorch versions :alt: Supported PyTorch versions
icefall depends on `k2 <https://github.com/k2-fsa/k2>`_ and .. |k2_versions| image:: ./images/k2-v-1.7.svg
:alt: Supported k2 versions
``icefall`` depends on `k2 <https://github.com/k2-fsa/k2>`_ and
`lhotse <https://github.com/lhotse-speech/lhotse>`_. `lhotse <https://github.com/lhotse-speech/lhotse>`_.
We recommend you to install ``k2`` first, as ``k2`` is bound to We recommend you to install ``k2`` first, as ``k2`` is bound to
@ -32,12 +36,16 @@ installs its dependency PyTorch, which can be reused by ``lhotse``.
-------------- --------------
Please refer to `<https://k2.readthedocs.io/en/latest/installation/index.html>`_ Please refer to `<https://k2.readthedocs.io/en/latest/installation/index.html>`_
to install `k2`. to install ``k2``.
.. CAUTION::
You need to install ``k2`` with a version at least **v1.7**.
.. HINT:: .. HINT::
If you have already installed PyTorch and don't want to replace it, If you have already installed PyTorch and don't want to replace it,
please install a version of k2 that is compiled against the version please install a version of ``k2`` that is compiled against the version
of PyTorch you are using. of PyTorch you are using.
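To check an installed version string against the **v1.7** requirement, a small helper along the following lines could be used. It is a sketch, not an official check: ``parse_major_minor`` and ``is_k2_new_enough`` are hypothetical names, and the comparison looks only at the major and minor numbers, so a nightly build such as ``1.7.dev20210908`` is treated as satisfying the requirement, as the CI configuration in this commit does.

```python
import re
from typing import Tuple


def parse_major_minor(version: str) -> Tuple[int, int]:
    """Extract (major, minor) from a version string such as
    "1.7.dev20210908+cpu.torch1.7.1"."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Cannot parse version: {version!r}")
    return int(match.group(1)), int(match.group(2))


def is_k2_new_enough(version: str, minimum: Tuple[int, int] = (1, 7)) -> bool:
    """Return True if the major.minor part of `version` is at least `minimum`."""
    return parse_major_minor(version) >= minimum
```

You can pass it the version reported by ``pip show k2``.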
(2) Install lhotse (2) Install lhotse
@ -50,10 +58,15 @@ to install ``lhotse``.
Install ``lhotse`` also installs its dependency `torchaudio <https://github.com/pytorch/audio>`_. Install ``lhotse`` also installs its dependency `torchaudio <https://github.com/pytorch/audio>`_.
.. CAUTION::
If you have installed ``torchaudio``, please consider uninstalling it before
installing ``lhotse``. Otherwise, it may update your already installed PyTorch.
(3) Download icefall (3) Download icefall
-------------------- --------------------
icefall is a collection of Python scripts, so you don't need to install it ``icefall`` is a collection of Python scripts, so you don't need to install it
and we don't provide a ``setup.py`` to install it. and we don't provide a ``setup.py`` to install it.
What you need is to download it and set the environment variable ``PYTHONPATH`` What you need is to download it and set the environment variable ``PYTHONPATH``
@ -202,22 +215,6 @@ The following shows an example about setting up the environment.
valtree-3.1.0 lhotse-0.8.0.dev-2a1410b-clean lilcom-1.1.1 numpy-1.21.2 packaging-21.0 pycparser-2.20 pyparsing-2.4.7 pyyaml-5.4.1 sor valtree-3.1.0 lhotse-0.8.0.dev-2a1410b-clean lilcom-1.1.1 numpy-1.21.2 packaging-21.0 pycparser-2.20 pyparsing-2.4.7 pyyaml-5.4.1 sor
tedcontainers-2.4.0 toolz-0.11.1 torchaudio-0.9.0 tqdm-4.62.1 tedcontainers-2.4.0 toolz-0.11.1 torchaudio-0.9.0 tqdm-4.62.1
**NOTE**: After installing ``lhotse``, you will encounter the following error:
.. code-block::
$ lhotse download --help
-bash: /ceph-fj/fangjun/test-icefall/bin/lhotse: python: bad interpreter: No such file or directory
The correct fix is:
.. code-block::
echo '#!/usr/bin/env python3' | cat - $(which lhotse) > /tmp/lhotse-bin
chmod +x /tmp/lhotse-bin
mv /tmp/lhotse-bin $(which lhotse)
(5) Download icefall (5) Download icefall
~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~
@ -383,7 +380,7 @@ Now let us run the training part:
.. CAUTION:: .. CAUTION::
We use ``export CUDA_VISIBLE_DEVICES=""`` so that icefall uses CPU We use ``export CUDA_VISIBLE_DEVICES=""`` so that ``icefall`` uses CPU
even if there are GPUs available. even if there are GPUs available.
The training log is given below: The training log is given below:

View File

@ -15,4 +15,3 @@ We may add recipes for other tasks as well in the future.
yesno yesno
librispeech librispeech

View File

@ -303,7 +303,7 @@ The commonly used options are:
- ``--lattice-score-scale`` - ``--lattice-score-scale``
It is used to scaled down lattice scores so that we can more unique It is used to scale down lattice scores so that there are more unique
paths for rescoring. paths for rescoring.
- ``--max-duration`` - ``--max-duration``
@ -314,7 +314,7 @@ The commonly used options are:
Pre-trained Model Pre-trained Model
----------------- -----------------
We have uploaded the pre-trained model to We have uploaded a pre-trained model to
`<https://huggingface.co/pkufool/icefall_asr_librispeech_conformer_ctc>`_. `<https://huggingface.co/pkufool/icefall_asr_librispeech_conformer_ctc>`_.
We describe how to use the pre-trained model to transcribe a sound file or We describe how to use the pre-trained model to transcribe a sound file or
@ -324,7 +324,7 @@ Install kaldifeat
~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~
`kaldifeat <https://github.com/csukuangfj/kaldifeat>`_ is used to `kaldifeat <https://github.com/csukuangfj/kaldifeat>`_ is used to
extract features for a single sound file or multiple soundfiles extract features for a single sound file or multiple sound files
at the same time. at the same time.
Please refer to `<https://github.com/csukuangfj/kaldifeat>`_ for installation. Please refer to `<https://github.com/csukuangfj/kaldifeat>`_ for installation.
@ -397,7 +397,7 @@ After downloading, you will have the following files:
- ``data/lm/G_4_gram.pt`` - ``data/lm/G_4_gram.pt``
It is a 4-gram LM, useful for LM rescoring. It is a 4-gram LM, used for n-gram LM rescoring.
- ``exp/pretrained.pt`` - ``exp/pretrained.pt``

View File

@ -21,6 +21,32 @@ To get more unique paths, we scaled the lattice.scores with 0.5 (see https://git
|test-clean|1.3|1.2| |test-clean|1.3|1.2|
|test-other|1.2|1.1| |test-other|1.2|1.1|
You can use the following commands to reproduce our results:
```bash
git clone https://github.com/k2-fsa/icefall
cd icefall
# It was using ef233486, you may not need to switch to it
# git checkout ef233486
cd egs/librispeech/ASR
./prepare.sh
export CUDA_VISIBLE_DEVICES="0,1,2,3"
python conformer_ctc/train.py --bucketing-sampler True \
--concatenate-cuts False \
--max-duration 200 \
--full-libri True \
--world-size 4
python conformer_ctc/decode.py --lattice-score-scale 0.5 \
--epoch 34 \
--avg 20 \
--method attention-decoder \
--max-duration 20 \
--num-paths 100
```
### LibriSpeech training results (Tdnn-Lstm) ### LibriSpeech training results (Tdnn-Lstm)
#### 2021-08-24 #### 2021-08-24
@ -43,4 +69,3 @@ We searched the lm_score_scale for best results, the scales that produced the WE
|--|--| |--|--|
|test-clean|0.8| |test-clean|0.8|
|test-other|0.9| |test-other|0.9|

View File

@ -45,6 +45,7 @@ from icefall.utils import (
get_texts, get_texts,
setup_logger, setup_logger,
store_transcripts, store_transcripts,
str2bool,
write_error_stats, write_error_stats,
) )
@ -78,16 +79,16 @@ def get_parser():
Supported values are: Supported values are:
- (1) 1best. Extract the best path from the decoding lattice as the - (1) 1best. Extract the best path from the decoding lattice as the
decoding result. decoding result.
- (2) nbest. Extract n paths from the decoding lattice; the path with - (2) nbest. Extract n paths from the decoding lattice; the path
the highest score is the decoding result. with the highest score is the decoding result.
- (3) nbest-rescoring. Extract n paths from the decoding lattice, - (3) nbest-rescoring. Extract n paths from the decoding lattice,
rescore them with an n-gram LM (e.g., a 4-gram LM), the path with rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
the highest score is the decoding result. the highest score is the decoding result.
- (4) whole-lattice-rescoring. Rescore the decoding lattice with an n-gram LM - (4) whole-lattice-rescoring. Rescore the decoding lattice with an
(e.g., a 4-gram LM), the best path of rescored lattice is the n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
decoding result. is the decoding result.
- (5) attention-decoder. Extract n paths from the LM rescored lattice, - (5) attention-decoder. Extract n paths from the LM rescored
the path with the highest score is the decoding result. lattice, the path with the highest score is the decoding result.
- (6) nbest-oracle. Its WER is the lower bound of any n-best - (6) nbest-oracle. Its WER is the lower bound of any n-best
rescoring method can achieve. Useful for debugging n-best rescoring method can achieve. Useful for debugging n-best
rescoring method. rescoring method.
@ -116,6 +117,17 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--export",
type=str2bool,
default=False,
help="""When enabled, the averaged model is saved to
conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved.
pretrained.pt contains a dict {"model": model.state_dict()},
which can be loaded by `icefall.checkpoint.load_checkpoint()`.
""",
)
return parser return parser
@ -541,6 +553,13 @@ def main():
logging.info(f"averaging {filenames}") logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames)) model.load_state_dict(average_checkpoints(filenames))
if params.export:
logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
torch.save(
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
)
return
model.to(device) model.to(device)
model.eval() model.eval()
num_param = sum([p.numel() for p in model.parameters()]) num_param = sum([p.numel() for p in model.parameters()])

View File

@ -16,9 +16,8 @@
# limitations under the License. # limitations under the License.
from subsampling import Conv2dSubsampling
from subsampling import VggSubsampling
import torch import torch
from subsampling import Conv2dSubsampling, VggSubsampling
def test_conv2d_subsampling(): def test_conv2d_subsampling():

View File

@ -17,17 +17,16 @@
import torch import torch
from torch.nn.utils.rnn import pad_sequence
from transformer import ( from transformer import (
Transformer, Transformer,
add_eos,
add_sos,
decoder_padding_mask,
encoder_padding_mask, encoder_padding_mask,
generate_square_subsequent_mask, generate_square_subsequent_mask,
decoder_padding_mask,
add_sos,
add_eos,
) )
from torch.nn.utils.rnn import pad_sequence
def test_encoder_padding_mask(): def test_encoder_padding_mask():
supervisions = { supervisions = {

View File

@ -102,14 +102,14 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
LG.labels[LG.labels >= first_token_disambig_id] = 0 LG.labels[LG.labels >= first_token_disambig_id] = 0
assert isinstance(LG.aux_labels, k2.RaggedInt) assert isinstance(LG.aux_labels, k2.RaggedTensor)
LG.aux_labels.values()[LG.aux_labels.values() >= first_word_disambig_id] = 0 LG.aux_labels.data[LG.aux_labels.data >= first_word_disambig_id] = 0
LG = k2.remove_epsilon(LG) LG = k2.remove_epsilon(LG)
logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}") logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
LG = k2.connect(LG) LG = k2.connect(LG)
LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0) LG.aux_labels = LG.aux_labels.remove_values_eq(0)
logging.info("Arc sorting LG") logging.info("Arc sorting LG")
LG = k2.arc_sort(LG) LG = k2.arc_sort(LG)

View File

@ -82,14 +82,14 @@ class LibriSpeechAsrDataModule(DataModule):
group.add_argument( group.add_argument(
"--max-duration", "--max-duration",
type=int, type=int,
default=500.0, default=200.0,
help="Maximum pooled recordings duration (seconds) in a " help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.", "single batch. You can reduce it if it causes CUDA OOM.",
) )
group.add_argument( group.add_argument(
"--bucketing-sampler", "--bucketing-sampler",
type=str2bool, type=str2bool,
default=False, default=True,
help="When enabled, the batches will come from buckets of " help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).", "similar duration (saves padding frames).",
) )

View File

@ -42,8 +42,8 @@ from icefall.utils import (
get_texts, get_texts,
setup_logger, setup_logger,
store_transcripts, store_transcripts,
write_error_stats,
str2bool, str2bool,
write_error_stats,
) )
@ -98,9 +98,11 @@ def get_params() -> AttributeDict:
# - nbest # - nbest
# - nbest-rescoring # - nbest-rescoring
# - whole-lattice-rescoring # - whole-lattice-rescoring
"method": "1best", "method": "whole-lattice-rescoring",
# "method": "1best",
# "method": "nbest",
# num_paths is used when method is "nbest" and "nbest-rescoring" # num_paths is used when method is "nbest" and "nbest-rescoring"
"num_paths": 30, "num_paths": 100,
} }
) )
return params return params
@ -424,6 +426,7 @@ def main():
torch.save( torch.save(
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
) )
return
model.to(device) model.to(device)
model.eval() model.eval()

0
egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py Normal file → Executable file
View File

View File

@ -10,5 +10,5 @@ get the following WER:
``` ```
Please refer to Please refer to
<https://icefal1.readthedocs.io/en/latest/recipes/yesno.html> <https://icefall.readthedocs.io/en/latest/recipes/yesno.html>
for detailed instructions. for detailed instructions.

View File

@ -80,14 +80,14 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
LG.labels[LG.labels >= first_token_disambig_id] = 0 LG.labels[LG.labels >= first_token_disambig_id] = 0
assert isinstance(LG.aux_labels, k2.RaggedInt) assert isinstance(LG.aux_labels, k2.RaggedTensor)
LG.aux_labels.values()[LG.aux_labels.values() >= first_word_disambig_id] = 0 LG.aux_labels.data[LG.aux_labels.data >= first_word_disambig_id] = 0
LG = k2.remove_epsilon(LG) LG = k2.remove_epsilon(LG)
logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}") logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
LG = k2.connect(LG) LG = k2.connect(LG)
LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0) LG.aux_labels = LG.aux_labels.remove_values_eq(0)
logging.info("Arc sorting LG") logging.info("Arc sorting LG")
LG = k2.arc_sort(LG) LG = k2.arc_sort(LG)

View File

@ -2,7 +2,7 @@
## How to run this recipe ## How to run this recipe
You can find detailed instructions by visiting You can find detailed instructions by visiting
<https://icefal1.readthedocs.io/en/latest/recipes/yesno.html> <https://icefall.readthedocs.io/en/latest/recipes/yesno.html>
It describes how to run this recipe and how to use It describes how to run this recipe and how to use
a pre-trained model with `./pretrained.py`. a pre-trained model with `./pretrained.py`.

View File

@ -296,6 +296,7 @@ def main():
torch.save( torch.save(
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
) )
return
model.to(device) model.to(device)
model.eval() model.eval()

View File

@ -84,8 +84,8 @@ def _intersect_device(
for start, end in splits: for start, end in splits:
indexes = torch.arange(start, end).to(b_to_a_map) indexes = torch.arange(start, end).to(b_to_a_map)
fsas = k2.index(b_fsas, indexes) fsas = k2.index_fsa(b_fsas, indexes)
b_to_a = k2.index(b_to_a_map, indexes) b_to_a = k2.index_select(b_to_a_map, indexes)
path_lattice = k2.intersect_device( path_lattice = k2.intersect_device(
a_fsas, fsas, b_to_a_map=b_to_a, sorted_match_a=sorted_match_a a_fsas, fsas, b_to_a_map=b_to_a, sorted_match_a=sorted_match_a
) )
@ -215,18 +215,16 @@ def nbest_decoding(
scale=scale, scale=scale,
) )
# word_seq is a k2.RaggedInt sharing the same shape as `path` # word_seq is a k2.RaggedTensor sharing the same shape as `path`
# but it contains word IDs. Note that it also contains 0s and -1s. # but it contains word IDs. Note that it also contains 0s and -1s.
# The last entry in each sublist is -1. # The last entry in each sublist is -1.
word_seq = k2.index(lattice.aux_labels, path) if isinstance(lattice.aux_labels, torch.Tensor):
# Note: the above operation supports also the case when word_seq = k2.ragged.index(lattice.aux_labels, path)
# lattice.aux_labels is a ragged tensor. In that case, else:
# `remove_axis=True` is used inside the pybind11 binding code, word_seq = lattice.aux_labels.index(path, remove_axis=True)
# so the resulting `word_seq` still has 3 axes, like `path`.
# The 3 axes are [seq][path][word_id]
# Remove 0 (epsilon) and -1 from word_seq # Remove 0 (epsilon) and -1 from word_seq
word_seq = k2.ragged.remove_values_leq(word_seq, 0) word_seq = word_seq.remove_values_leq(0)
# Remove sequences with identical word sequences. # Remove sequences with identical word sequences.
# #
@ -234,12 +232,12 @@ def nbest_decoding(
# `new2old` is a 1-D torch.Tensor mapping from the output path index # `new2old` is a 1-D torch.Tensor mapping from the output path index
# to the input path index. # to the input path index.
# new2old.numel() == unique_word_seqs.tot_size(1) # new2old.numel() == unique_word_seqs.tot_size(1)
unique_word_seq, _, new2old = k2.ragged.unique_sequences( unique_word_seq, _, new2old = word_seq.unique(
word_seq, need_num_repeats=False, need_new2old_indexes=True need_num_repeats=False, need_new2old_indexes=True
) )
# Note: unique_word_seq still has the same axes as word_seq # Note: unique_word_seq still has the same axes as word_seq
seq_to_path_shape = k2.ragged.get_layer(unique_word_seq.shape(), 0) seq_to_path_shape = unique_word_seq.shape.get_layer(0)
# path_to_seq_map is a 1-D torch.Tensor. # path_to_seq_map is a 1-D torch.Tensor.
# path_to_seq_map[i] is the seq to which the i-th path belongs # path_to_seq_map[i] is the seq to which the i-th path belongs
@ -247,7 +245,7 @@ def nbest_decoding(
# Remove the seq axis. # Remove the seq axis.
# Now unique_word_seq has only two axes [path][word] # Now unique_word_seq has only two axes [path][word]
unique_word_seq = k2.ragged.remove_axis(unique_word_seq, 0) unique_word_seq = unique_word_seq.remove_axis(0)
# word_fsa is an FsaVec with axes [path][state][arc] # word_fsa is an FsaVec with axes [path][state][arc]
word_fsa = k2.linear_fsa(unique_word_seq) word_fsa = k2.linear_fsa(unique_word_seq)
@ -275,35 +273,35 @@ def nbest_decoding(
use_double_scores=use_double_scores, log_semiring=False use_double_scores=use_double_scores, log_semiring=False
) )
# RaggedFloat currently supports float32 only. ragged_tot_scores = k2.RaggedTensor(seq_to_path_shape, tot_scores)
# If Ragged<double> is wrapped, we can use k2.RaggedDouble here
ragged_tot_scores = k2.RaggedFloat(
seq_to_path_shape, tot_scores.to(torch.float32)
)
argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores) argmax_indexes = ragged_tot_scores.argmax()
# Since we invoked `k2.ragged.unique_sequences`, which reorders # Since we invoked `k2.ragged.unique_sequences`, which reorders
# the index from `path`, we use `new2old` here to convert argmax_indexes # the index from `path`, we use `new2old` here to convert argmax_indexes
# to the indexes into `path`. # to the indexes into `path`.
# #
# Use k2.index here since argmax_indexes' dtype is torch.int32 # Use k2.index here since argmax_indexes' dtype is torch.int32
best_path_indexes = k2.index(new2old, argmax_indexes) best_path_indexes = k2.index_select(new2old, argmax_indexes)
path_2axes = k2.ragged.remove_axis(path, 0) path_2axes = path.remove_axis(0)
# best_path is a k2.RaggedInt with 2 axes [path][arc_pos] # best_path is a k2.RaggedTensor with 2 axes [path][arc_pos]
best_path = k2.index(path_2axes, best_path_indexes) best_path, _ = path_2axes.index(
indexes=best_path_indexes, axis=0, need_value_indexes=False
)
# labels is a k2.RaggedInt with 2 axes [path][token_id] # labels is a k2.RaggedTensor with 2 axes [path][token_id]
# Note that it contains -1s. # Note that it contains -1s.
labels = k2.index(lattice.labels.contiguous(), best_path) labels = k2.ragged.index(lattice.labels.contiguous(), best_path)
labels = k2.ragged.remove_values_eq(labels, -1) labels = labels.remove_values_eq(-1)
# lattice.aux_labels is a k2.RaggedInt tensor with 2 axes, so # lattice.aux_labels is a k2.RaggedTensor with 2 axes, so
# aux_labels is also a k2.RaggedInt with 2 axes # aux_labels is also a k2.RaggedTensor with 2 axes
aux_labels = k2.index(lattice.aux_labels, best_path.values()) aux_labels, _ = lattice.aux_labels.index(
indexes=best_path.data, axis=0, need_value_indexes=False
)
best_path_fsa = k2.linear_fsa(labels) best_path_fsa = k2.linear_fsa(labels)
best_path_fsa.aux_labels = aux_labels best_path_fsa.aux_labels = aux_labels
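The selection logic in this hunk can be hard to follow through the ragged-tensor calls, so here is a plain-Python illustration of the same steps on nested lists with axes [seq][path][word]. It does not use k2 and is not the k2 implementation: all function names are hypothetical, the score of a path comes from a caller-supplied `score_fn` instead of an FSA intersection, unique paths are kept in first-occurrence order (k2 may order them differently), and every seq is assumed to contain at least one path.

```python
from typing import Callable, List, Tuple

Paths = List[List[List[int]]]  # axes: [seq][path][word]


def remove_values_leq(word_seq: Paths, cutoff: int) -> Paths:
    """Drop every value <= cutoff, e.g. epsilons (0) and the final -1."""
    return [[[w for w in path if w > cutoff] for path in seq] for seq in word_seq]


def unique_paths(word_seq: Paths) -> Tuple[Paths, List[int]]:
    """Remove duplicate paths within each seq.

    new2old[i] is the index, counted over all input paths, of the
    i-th path that was kept.
    """
    unique: Paths = []
    new2old: List[int] = []
    offset = 0  # number of input paths in the preceding seqs
    for seq in word_seq:
        seen = set()
        kept = []
        for i, path in enumerate(seq):
            key = tuple(path)
            if key not in seen:
                seen.add(key)
                kept.append(path)
                new2old.append(offset + i)
        unique.append(kept)
        offset += len(seq)
    return unique, new2old


def best_path_indexes(
    word_seq: Paths, score_fn: Callable[[List[int]], float]
) -> List[int]:
    """For each seq, return the input index of its highest-scoring path."""
    cleaned = remove_values_leq(word_seq, 0)
    unique, new2old = unique_paths(cleaned)

    best = []
    start = 0  # position of the current seq's first path in new2old
    for seq in unique:
        scores = [score_fn(path) for path in seq]
        argmax = max(range(len(scores)), key=scores.__getitem__)
        # Map the index among unique paths back to the original paths.
        best.append(new2old[start + argmax])
        start += len(seq)
    return best
```

The final lookup through `new2old` plays the same role as `k2.index_select(new2old, argmax_indexes)` in the code above: the argmax is taken over the deduplicated paths, but the result has to index into the original `path`.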
@ -426,33 +424,36 @@ def rescore_with_n_best_list(
scale=scale, scale=scale,
) )
# word_seq is a k2.RaggedInt sharing the same shape as `path` # word_seq is a k2.RaggedTensor sharing the same shape as `path`
# but it contains word IDs. Note that it also contains 0s and -1s. # but it contains word IDs. Note that it also contains 0s and -1s.
# The last entry in each sublist is -1. # The last entry in each sublist is -1.
word_seq = k2.index(lattice.aux_labels, path) if isinstance(lattice.aux_labels, torch.Tensor):
word_seq = k2.ragged.index(lattice.aux_labels, path)
else:
word_seq = lattice.aux_labels.index(path, remove_axis=True)
# Remove epsilons and -1 from word_seq # Remove epsilons and -1 from word_seq
word_seq = k2.ragged.remove_values_leq(word_seq, 0) word_seq = word_seq.remove_values_leq(0)
# Remove paths that has identical word sequences. # Remove paths that has identical word sequences.
# #
# unique_word_seq is still a k2.RaggedInt with 3 axes [seq][path][word] # unique_word_seq is still a k2.RaggedTensor with 3 axes [seq][path][word]
# except that there are no repeated paths with the same word_seq # except that there are no repeated paths with the same word_seq
# within a sequence. # within a sequence.
# #
# num_repeats is also a k2.RaggedInt with 2 axes containing the # num_repeats is also a k2.RaggedTensor with 2 axes containing the
# multiplicities of each path. # multiplicities of each path.
# num_repeats.num_elements() == unique_word_seqs.tot_size(1) # num_repeats.numel() == unique_word_seqs.tot_size(1)
# #
# Since k2.ragged.unique_sequences will reorder paths within a seq, # Since k2.ragged.unique_sequences will reorder paths within a seq,
# `new2old` is a 1-D torch.Tensor mapping from the output path index # `new2old` is a 1-D torch.Tensor mapping from the output path index
# to the input path index. # to the input path index.
# new2old.numel() == unique_word_seqs.tot_size(1) # new2old.numel() == unique_word_seqs.tot_size(1)
unique_word_seq, num_repeats, new2old = k2.ragged.unique_sequences( unique_word_seq, num_repeats, new2old = word_seq.unique(
word_seq, need_num_repeats=True, need_new2old_indexes=True need_num_repeats=True, need_new2old_indexes=True
) )
seq_to_path_shape = k2.ragged.get_layer(unique_word_seq.shape(), 0) seq_to_path_shape = unique_word_seq.shape.get_layer(0)
# path_to_seq_map is a 1-D torch.Tensor. # path_to_seq_map is a 1-D torch.Tensor.
# path_to_seq_map[i] is the seq to which the i-th path # path_to_seq_map[i] is the seq to which the i-th path
@ -461,7 +462,7 @@ def rescore_with_n_best_list(
# Remove the seq axis. # Remove the seq axis.
# Now unique_word_seq has only two axes [path][word] # Now unique_word_seq has only two axes [path][word]
unique_word_seq = k2.ragged.remove_axis(unique_word_seq, 0) unique_word_seq = unique_word_seq.remove_axis(0)
# word_fsa is an FsaVec with axes [path][state][arc] # word_fsa is an FsaVec with axes [path][state][arc]
word_fsa = k2.linear_fsa(unique_word_seq) word_fsa = k2.linear_fsa(unique_word_seq)
@ -485,39 +486,42 @@ def rescore_with_n_best_list(
use_double_scores=True, log_semiring=False use_double_scores=True, log_semiring=False
) )
path_2axes = k2.ragged.remove_axis(path, 0) path_2axes = path.remove_axis(0)
ans = dict() ans = dict()
for lm_scale in lm_scale_list: for lm_scale in lm_scale_list:
tot_scores = am_scores / lm_scale + lm_scores tot_scores = am_scores / lm_scale + lm_scores
# Remember that we used `k2.ragged.unique_sequences` to remove repeated # Remember that we used `k2.RaggedTensor.unique` to remove repeated
# paths to avoid redundant computation in `k2.intersect_device`. # paths to avoid redundant computation in `k2.intersect_device`.
# Now we use `num_repeats` to correct the scores for each path. # Now we use `num_repeats` to correct the scores for each path.
# #
# NOTE(fangjun): It is commented out as it leads to a worse WER # NOTE(fangjun): It is commented out as it leads to a worse WER
# tot_scores = tot_scores * num_repeats.values() # tot_scores = tot_scores * num_repeats.values()
ragged_tot_scores = k2.RaggedFloat( ragged_tot_scores = k2.RaggedTensor(seq_to_path_shape, tot_scores)
seq_to_path_shape, tot_scores.to(torch.float32) argmax_indexes = ragged_tot_scores.argmax()
)
argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)
# Use k2.index here since argmax_indexes' dtype is torch.int32 # Use k2.index here since argmax_indexes' dtype is torch.int32
best_path_indexes = k2.index(new2old, argmax_indexes) best_path_indexes = k2.index_select(new2old, argmax_indexes)
# best_path is a k2.RaggedInt with 2 axes [path][arc_pos] # best_path is a k2.RaggedInt with 2 axes [path][arc_pos]
best_path = k2.index(path_2axes, best_path_indexes) best_path, _ = path_2axes.index(
indexes=best_path_indexes, axis=0, need_value_indexes=False
)
# labels is a k2.RaggedInt with 2 axes [path][phone_id] # labels is a k2.RaggedTensor with 2 axes [path][phone_id]
# Note that it contains -1s. # Note that it contains -1s.
labels = k2.index(lattice.labels.contiguous(), best_path) labels = k2.ragged.index(lattice.labels.contiguous(), best_path)
labels = k2.ragged.remove_values_eq(labels, -1) labels = labels.remove_values_eq(-1)
# lattice.aux_labels is a k2.RaggedInt tensor with 2 axes, so # lattice.aux_labels is a k2.RaggedTensor tensor with 2 axes, so
# aux_labels is also a k2.RaggedInt with 2 axes # aux_labels is also a k2.RaggedTensor with 2 axes
aux_labels = k2.index(lattice.aux_labels, best_path.values())
aux_labels, _ = lattice.aux_labels.index(
indexes=best_path.data, axis=0, need_value_indexes=False
)
best_path_fsa = k2.linear_fsa(labels) best_path_fsa = k2.linear_fsa(labels)
best_path_fsa.aux_labels = aux_labels best_path_fsa.aux_labels = aux_labels
@ -659,12 +663,16 @@ def nbest_oracle(
scale=scale, scale=scale,
) )
word_seq = k2.index(lattice.aux_labels, path) if isinstance(lattice.aux_labels, torch.Tensor):
word_seq = k2.ragged.remove_values_leq(word_seq, 0) word_seq = k2.ragged.index(lattice.aux_labels, path)
unique_word_seq, _, _ = k2.ragged.unique_sequences( else:
word_seq, need_num_repeats=False, need_new2old_indexes=False word_seq = lattice.aux_labels.index(path, remove_axis=True)
word_seq = word_seq.remove_values_leq(0)
unique_word_seq, _, _ = word_seq.unique(
need_num_repeats=False, need_new2old_indexes=False
) )
unique_word_ids = k2.ragged.to_list(unique_word_seq) unique_word_ids = unique_word_seq.tolist()
assert len(unique_word_ids) == len(ref_texts) assert len(unique_word_ids) == len(ref_texts)
# unique_word_ids[i] contains all hypotheses of the i-th utterance # unique_word_ids[i] contains all hypotheses of the i-th utterance
@ -743,33 +751,36 @@ def rescore_with_attention_decoder(
scale=scale, scale=scale,
) )
# word_seq is a k2.RaggedInt sharing the same shape as `path` # word_seq is a k2.RaggedTensor sharing the same shape as `path`
# but it contains word IDs. Note that it also contains 0s and -1s. # but it contains word IDs. Note that it also contains 0s and -1s.
# The last entry in each sublist is -1. # The last entry in each sublist is -1.
word_seq = k2.index(lattice.aux_labels, path) if isinstance(lattice.aux_labels, torch.Tensor):
word_seq = k2.ragged.index(lattice.aux_labels, path)
else:
word_seq = lattice.aux_labels.index(path, remove_axis=True)
# Remove epsilons and -1 from word_seq # Remove epsilons and -1 from word_seq
word_seq = k2.ragged.remove_values_leq(word_seq, 0) word_seq = word_seq.remove_values_leq(0)
# Remove paths that has identical word sequences. # Remove paths that has identical word sequences.
# #
# unique_word_seq is still a k2.RaggedInt with 3 axes [seq][path][word] # unique_word_seq is still a k2.RaggedTensor with 3 axes [seq][path][word]
# except that there are no repeated paths with the same word_seq # except that there are no repeated paths with the same word_seq
# within a sequence. # within a sequence.
# #
# num_repeats is also a k2.RaggedInt with 2 axes containing the # num_repeats is also a k2.RaggedTensor with 2 axes containing the
# multiplicities of each path. # multiplicities of each path.
# num_repeats.num_elements() == unique_word_seqs.tot_size(1) # num_repeats.numel() == unique_word_seqs.tot_size(1)
# #
# Since k2.ragged.unique_sequences will reorder paths within a seq, # Since k2.ragged.unique_sequences will reorder paths within a seq,
# `new2old` is a 1-D torch.Tensor mapping from the output path index # `new2old` is a 1-D torch.Tensor mapping from the output path index
# to the input path index. # to the input path index.
# new2old.numel() == unique_word_seq.tot_size(1) # new2old.numel() == unique_word_seq.tot_size(1)
unique_word_seq, num_repeats, new2old = k2.ragged.unique_sequences( unique_word_seq, num_repeats, new2old = word_seq.unique(
word_seq, need_num_repeats=True, need_new2old_indexes=True need_num_repeats=True, need_new2old_indexes=True
) )
seq_to_path_shape = k2.ragged.get_layer(unique_word_seq.shape(), 0) seq_to_path_shape = unique_word_seq.shape.get_layer(0)
# path_to_seq_map is a 1-D torch.Tensor. # path_to_seq_map is a 1-D torch.Tensor.
# path_to_seq_map[i] is the seq to which the i-th path # path_to_seq_map[i] is the seq to which the i-th path
@ -778,7 +789,7 @@ def rescore_with_attention_decoder(
# Remove the seq axis. # Remove the seq axis.
# Now unique_word_seq has only two axes [path][word] # Now unique_word_seq has only two axes [path][word]
unique_word_seq = k2.ragged.remove_axis(unique_word_seq, 0) unique_word_seq = unique_word_seq.remove_axis(0)
# word_fsa is an FsaVec with axes [path][state][arc] # word_fsa is an FsaVec with axes [path][state][arc]
word_fsa = k2.linear_fsa(unique_word_seq) word_fsa = k2.linear_fsa(unique_word_seq)
@ -796,20 +807,23 @@ def rescore_with_attention_decoder(
# CAUTION: The "tokens" attribute is set in the file # CAUTION: The "tokens" attribute is set in the file
# local/compile_hlg.py # local/compile_hlg.py
token_seq = k2.index(lattice.tokens, path) if isinstance(lattice.tokens, torch.Tensor):
token_seq = k2.ragged.index(lattice.tokens, path)
else:
token_seq = lattice.tokens.index(path, remove_axis=True)
# Remove epsilons and -1 from token_seq # Remove epsilons and -1 from token_seq
token_seq = k2.ragged.remove_values_leq(token_seq, 0) token_seq = token_seq.remove_values_leq(0)
# Remove the seq axis. # Remove the seq axis.
token_seq = k2.ragged.remove_axis(token_seq, 0) token_seq = token_seq.remove_axis(0)
token_seq, _ = k2.ragged.index( token_seq, _ = token_seq.index(
token_seq, indexes=new2old, axis=0, need_value_indexes=False indexes=new2old, axis=0, need_value_indexes=False
) )
# Now word in unique_word_seq has its corresponding token IDs. # Now word in unique_word_seq has its corresponding token IDs.
token_ids = k2.ragged.to_list(token_seq) token_ids = token_seq.tolist()
num_word_seqs = new2old.numel() num_word_seqs = new2old.numel()
@ -849,7 +863,7 @@ def rescore_with_attention_decoder(
else: else:
attention_scale_list = [attention_scale] attention_scale_list = [attention_scale]
path_2axes = k2.ragged.remove_axis(path, 0) path_2axes = path.remove_axis(0)
ans = dict() ans = dict()
for n_scale in ngram_lm_scale_list: for n_scale in ngram_lm_scale_list:
@ -859,23 +873,28 @@ def rescore_with_attention_decoder(
+ n_scale * ngram_lm_scores + n_scale * ngram_lm_scores
+ a_scale * attention_scores + a_scale * attention_scores
) )
ragged_tot_scores = k2.RaggedFloat(seq_to_path_shape, tot_scores) ragged_tot_scores = k2.RaggedTensor(seq_to_path_shape, tot_scores)
argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores) argmax_indexes = ragged_tot_scores.argmax()
best_path_indexes = k2.index(new2old, argmax_indexes) best_path_indexes = k2.index_select(new2old, argmax_indexes)
# best_path is a k2.RaggedInt with 2 axes [path][arc_pos] # best_path is a k2.RaggedInt with 2 axes [path][arc_pos]
best_path = k2.index(path_2axes, best_path_indexes) best_path, _ = path_2axes.index(
indexes=best_path_indexes, axis=0, need_value_indexes=False
)
# labels is a k2.RaggedInt with 2 axes [path][token_id] # labels is a k2.RaggedTensor with 2 axes [path][token_id]
# Note that it contains -1s. # Note that it contains -1s.
labels = k2.index(lattice.labels.contiguous(), best_path) labels = k2.ragged.index(lattice.labels.contiguous(), best_path)
labels = k2.ragged.remove_values_eq(labels, -1) labels = labels.remove_values_eq(-1)
# lattice.aux_labels is a k2.RaggedInt tensor with 2 axes, so if isinstance(lattice.aux_labels, torch.Tensor):
# aux_labels is also a k2.RaggedInt with 2 axes aux_labels = k2.index_select(lattice.aux_labels, best_path.data)
aux_labels = k2.index(lattice.aux_labels, best_path.values()) else:
aux_labels, _ = lattice.aux_labels.index(
indexes=best_path.data, axis=0, need_value_indexes=False
)
best_path_fsa = k2.linear_fsa(labels) best_path_fsa = k2.linear_fsa(labels)
best_path_fsa.aux_labels = aux_labels best_path_fsa.aux_labels = aux_labels
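
The selection step that closes each rescoring loop above (take the argmax of `tot_scores` within each utterance, then map it back through `new2old`) can be sketched in plain Python. This is only an illustration under stated assumptions, not the k2 implementation: `argmax_per_sublist` and `select_best_paths` are hypothetical helper names, the ragged shape is modeled as a row-splits list, ties go to the first maximum, and an empty sublist yields -1; k2 itself may behave differently on those edge cases.

```python
from typing import List


def argmax_per_sublist(row_splits: List[int], values: List[float]) -> List[int]:
    """Return, for each sublist, the index into `values` of its maximum.

    `row_splits` plays the role of `seq_to_path_shape`: sublist i covers
    values[row_splits[i]:row_splits[i + 1]]. Empty sublists give -1
    (an assumption of this sketch).
    """
    ans = []
    for begin, end in zip(row_splits[:-1], row_splits[1:]):
        if begin == end:
            ans.append(-1)
        else:
            # max() returns the first maximal index when scores tie.
            ans.append(max(range(begin, end), key=lambda i: values[i]))
    return ans


def select_best_paths(
    row_splits: List[int], tot_scores: List[float], new2old: List[int]
) -> List[int]:
    """Map each per-utterance argmax back to an index into the original paths."""
    argmax_indexes = argmax_per_sublist(row_splits, tot_scores)
    # Scores are indexed by unique path, so new2old converts the winner's
    # position among the unique paths into its position in `path`.
    return [new2old[i] if i >= 0 else -1 for i in argmax_indexes]


if __name__ == "__main__":
    # Two utterances with 3 and 2 unique paths (made-up numbers).
    print(select_best_paths([0, 3, 5], [-1.0, -0.5, -2.0, -3.0, -0.1], [2, 0, 1, 4, 3]))
```

The indirection through `new2old` is needed because deduplication reorders paths within an utterance, so the argmax position alone does not identify a row of the original `path`.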


@ -157,7 +157,7 @@ class BpeLexicon(Lexicon):
lang_dir / "lexicon.txt" lang_dir / "lexicon.txt"
) )
def convert_lexicon_to_ragged(self, filename: str) -> k2.RaggedInt: def convert_lexicon_to_ragged(self, filename: str) -> k2.RaggedTensor:
"""Read a BPE lexicon from file and convert it to a """Read a BPE lexicon from file and convert it to a
k2 ragged tensor. k2 ragged tensor.
@ -200,19 +200,18 @@ class BpeLexicon(Lexicon):
) )
values = torch.tensor(token_ids, dtype=torch.int32) values = torch.tensor(token_ids, dtype=torch.int32)
return k2.RaggedInt(shape, values) return k2.RaggedTensor(shape, values)
def words_to_piece_ids(self, words: List[str]) -> k2.RaggedInt: def words_to_piece_ids(self, words: List[str]) -> k2.RaggedTensor:
"""Convert a list of words to a ragged tensor contained """Convert a list of words to a ragged tensor contained
word piece IDs. word piece IDs.
""" """
word_ids = [self.word_table[w] for w in words] word_ids = [self.word_table[w] for w in words]
word_ids = torch.tensor(word_ids, dtype=torch.int32) word_ids = torch.tensor(word_ids, dtype=torch.int32)
ragged, _ = k2.ragged.index( ragged, _ = self.ragged_lexicon.index(
self.ragged_lexicon,
indexes=word_ids, indexes=word_ids,
need_value_indexes=False,
axis=0, axis=0,
need_value_indexes=False,
) )
return ragged return ragged
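
What `words_to_piece_ids` computes can be pictured with nested Python lists standing in for the ragged lexicon: indexing along axis 0 simply picks one sublist of piece IDs per word. The sketch below is an assumption-laden illustration; the table contents and the free function `words_to_piece_ids_sketch` are hypothetical and do not use k2.

```python
from typing import Dict, List


def words_to_piece_ids_sketch(
    word_table: Dict[str, int],
    ragged_lexicon: List[List[int]],
    words: List[str],
) -> List[List[int]]:
    """Look up each word's ID, then select that row of the lexicon.

    ragged_lexicon[i] holds the word-piece IDs of the word whose ID is i,
    mirroring a ragged tensor with axes [word][piece].
    """
    word_ids = [word_table[w] for w in words]
    # Equivalent in spirit to indexing a ragged tensor along axis 0.
    return [list(ragged_lexicon[i]) for i in word_ids]


if __name__ == "__main__":
    # Hypothetical toy lexicon: word ID 0 has no pieces.
    table = {"<eps>": 0, "hello": 1, "world": 2}
    lexicon = [[], [5, 9], [7]]
    print(words_to_piece_ids_sketch(table, lexicon, ["world", "hello", "world"]))
```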


@ -26,7 +26,6 @@ from pathlib import Path
from typing import Dict, Iterable, List, TextIO, Tuple, Union from typing import Dict, Iterable, List, TextIO, Tuple, Union
import k2 import k2
import k2.ragged as k2r
import kaldialign import kaldialign
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@ -199,26 +198,25 @@ def get_texts(best_paths: k2.Fsa) -> List[List[int]]:
Returns a list of lists of int, containing the label sequences we Returns a list of lists of int, containing the label sequences we
decoded. decoded.
""" """
if isinstance(best_paths.aux_labels, k2.RaggedInt): if isinstance(best_paths.aux_labels, k2.RaggedTensor):
# remove 0's and -1's. # remove 0's and -1's.
aux_labels = k2r.remove_values_leq(best_paths.aux_labels, 0) aux_labels = best_paths.aux_labels.remove_values_leq(0)
aux_shape = k2r.compose_ragged_shapes( # TODO: change arcs.shape() to arcs.shape
best_paths.arcs.shape(), aux_labels.shape() aux_shape = best_paths.arcs.shape().compose(aux_labels.shape)
)
# remove the states and arcs axes. # remove the states and arcs axes.
aux_shape = k2r.remove_axis(aux_shape, 1) aux_shape = aux_shape.remove_axis(1)
aux_shape = k2r.remove_axis(aux_shape, 1) aux_shape = aux_shape.remove_axis(1)
aux_labels = k2.RaggedInt(aux_shape, aux_labels.values()) aux_labels = k2.RaggedTensor(aux_shape, aux_labels.data)
else: else:
# remove axis corresponding to states. # remove axis corresponding to states.
aux_shape = k2r.remove_axis(best_paths.arcs.shape(), 1) aux_shape = best_paths.arcs.shape().remove_axis(1)
aux_labels = k2.RaggedInt(aux_shape, best_paths.aux_labels) aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels)
# remove 0's and -1's. # remove 0's and -1's.
aux_labels = k2r.remove_values_leq(aux_labels, 0) aux_labels = aux_labels.remove_values_leq(0)
assert aux_labels.num_axes() == 2 assert aux_labels.num_axes == 2
return k2r.to_list(aux_labels) return aux_labels.tolist()
def store_transcripts( def store_transcripts(


@ -16,9 +16,10 @@
# limitations under the License. # limitations under the License.
from pathlib import Path
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.lexicon import BpeLexicon from icefall.lexicon import BpeLexicon
from pathlib import Path
def test(): def test():


@ -60,7 +60,7 @@ def test_get_texts_ragged():
4 4
""" """
) )
fsa1.aux_labels = k2.RaggedInt("[ [1 3 0 2] [] [4 0 1] [-1]]") fsa1.aux_labels = k2.RaggedTensor("[ [1 3 0 2] [] [4 0 1] [-1]]")
fsa2 = k2.Fsa.from_str( fsa2 = k2.Fsa.from_str(
""" """
@ -70,7 +70,7 @@ def test_get_texts_ragged():
3 3
""" """
) )
fsa2.aux_labels = k2.RaggedInt("[[3 0 5 0 8] [0 9 7 0] [-1]]") fsa2.aux_labels = k2.RaggedTensor("[[3 0 5 0 8] [0 9 7 0] [-1]]")
fsas = k2.Fsa.from_fsas([fsa1, fsa2]) fsas = k2.Fsa.from_fsas([fsa1, fsa2])
texts = get_texts(fsas) texts = get_texts(fsas)
assert texts == [[1, 3, 2, 4, 1], [3, 5, 8, 9, 7]] assert texts == [[1, 3, 2, 4, 1], [3, 5, 8, 9, 7]]
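
The expected value in the assertion above follows from what the ragged branch of `get_texts` does: per FSA, concatenate the per-arc `aux_labels` sublists and drop every value less than or equal to 0 (epsilons and the final -1). A pure-Python sketch of that reduction, using nested lists in place of k2 ragged tensors, is given below; `flatten_and_filter` is a hypothetical name introduced only for illustration.

```python
from typing import List


def flatten_and_filter(aux_labels_per_fsa: List[List[List[int]]]) -> List[List[int]]:
    """Collapse [fsa][arc][label] to [fsa][label], keeping only labels > 0."""
    return [
        [label for arc_labels in fsa for label in arc_labels if label > 0]
        for fsa in aux_labels_per_fsa
    ]


if __name__ == "__main__":
    # Same aux_labels as fsa1 and fsa2 in the test above.
    fsa1 = [[1, 3, 0, 2], [], [4, 0, 1], [-1]]
    fsa2 = [[3, 0, 5, 0, 8], [0, 9, 7, 0], [-1]]
    print(flatten_and_filter([fsa1, fsa2]))
```

This mirrors only the label bookkeeping; the real function also has to compose the arc shape with the label shape and remove the state and arc axes, which nested lists make implicit.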