Use new APIs with k2.RaggedTensor (#38)

* Use new APIs with k2.RaggedTensor

* Fix style issues.

* Update the installation doc, saying it requires at least k2 v1.7

* Use k2 v1.7
This commit is contained in:
Fangjun Kuang 2021-09-08 14:55:30 +08:00 committed by GitHub
parent 331e5eb7ab
commit abadc71415
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 197 additions and 147 deletions

View File

@ -56,7 +56,7 @@ jobs:
run: | run: |
python3 -m pip install --upgrade pip black flake8 python3 -m pip install --upgrade pip black flake8
python3 -m pip install -U pip python3 -m pip install -U pip
python3 -m pip install k2==1.4.dev20210822+cpu.torch1.7.1 -f https://k2-fsa.org/nightly/ python3 -m pip install k2==1.7.dev20210908+cpu.torch1.7.1 -f https://k2-fsa.org/nightly/
python3 -m pip install torchaudio==0.7.2 python3 -m pip install torchaudio==0.7.2
python3 -m pip install git+https://github.com/lhotse-speech/lhotse python3 -m pip install git+https://github.com/lhotse-speech/lhotse

View File

@ -32,7 +32,8 @@ jobs:
os: [ubuntu-18.04, macos-10.15] os: [ubuntu-18.04, macos-10.15]
python-version: [3.6, 3.7, 3.8, 3.9] python-version: [3.6, 3.7, 3.8, 3.9]
torch: ["1.8.1"] torch: ["1.8.1"]
k2-version: ["1.4.dev20210822"] k2-version: ["1.7.dev20210908"]
fail-fast: false fail-fast: false
steps: steps:

2
.gitignore vendored
View File

@ -4,4 +4,4 @@ path.sh
exp exp
exp*/ exp*/
*.pt *.pt
download/ download

View File

@ -16,7 +16,6 @@
import sphinx_rtd_theme import sphinx_rtd_theme
# -- Project information ----------------------------------------------------- # -- Project information -----------------------------------------------------
project = "icefall" project = "icefall"

View File

@ -1 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" width="122" height="20" role="img" aria-label="device: CPU | CUDA"><title>device: CPU | CUDA</title><linearGradient id="s" x2="0" y2="100%"><stop offset="0" stop-color="#bbb" stop-opacity=".1"/><stop offset="1" stop-opacity=".1"/></linearGradient><clipPath id="r"><rect width="122" height="20" rx="3" fill="#fff"/></clipPath><g clip-path="url(#r)"><rect width="45" height="20" fill="#555"/><rect x="45" width="77" height="20" fill="#fe7d37"/><rect width="122" height="20" fill="url(#s)"/></g><g fill="#fff" text-anchor="middle" font-family="Verdana,Geneva,DejaVu Sans,sans-serif" text-rendering="geometricPrecision" font-size="110"><text aria-hidden="true" x="235" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="350">device</text><text x="235" y="140" transform="scale(.1)" fill="#fff" textLength="350">device</text><text aria-hidden="true" x="825" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="670">CPU | CUDA</text><text x="825" y="140" transform="scale(.1)" fill="#fff" textLength="670">CPU | CUDA</text></g></svg> <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" width="122" height="20" role="img" aria-label="device: CPU | CUDA"><title>device: CPU | CUDA</title><linearGradient id="s" x2="0" y2="100%"><stop offset="0" stop-color="#bbb" stop-opacity=".1"/><stop offset="1" stop-opacity=".1"/></linearGradient><clipPath id="r"><rect width="122" height="20" rx="3" fill="#fff"/></clipPath><g clip-path="url(#r)"><rect width="45" height="20" fill="#555"/><rect x="45" width="77" height="20" fill="#fe7d37"/><rect width="122" height="20" fill="url(#s)"/></g><g fill="#fff" text-anchor="middle" font-family="Verdana,Geneva,DejaVu Sans,sans-serif" text-rendering="geometricPrecision" font-size="110"><text aria-hidden="true" x="235" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" 
textLength="350">device</text><text x="235" y="140" transform="scale(.1)" fill="#fff" textLength="350">device</text><text aria-hidden="true" x="825" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="670">CPU | CUDA</text><text x="825" y="140" transform="scale(.1)" fill="#fff" textLength="670">CPU | CUDA</text></g></svg>

Before

Width:  |  Height:  |  Size: 1.1 KiB

After

Width:  |  Height:  |  Size: 1.1 KiB

View File

@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" width="80" height="20" role="img" aria-label="k2: &gt;= v1.7"><title>k2: &gt;= v1.7</title><linearGradient id="s" x2="0" y2="100%"><stop offset="0" stop-color="#bbb" stop-opacity=".1"/><stop offset="1" stop-opacity=".1"/></linearGradient><clipPath id="r"><rect width="80" height="20" rx="3" fill="#fff"/></clipPath><g clip-path="url(#r)"><rect width="23" height="20" fill="#555"/><rect x="23" width="57" height="20" fill="blueviolet"/><rect width="80" height="20" fill="url(#s)"/></g><g fill="#fff" text-anchor="middle" font-family="Verdana,Geneva,DejaVu Sans,sans-serif" text-rendering="geometricPrecision" font-size="110"><text aria-hidden="true" x="125" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="130">k2</text><text x="125" y="140" transform="scale(.1)" fill="#fff" textLength="130">k2</text><text aria-hidden="true" x="505" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="470">&gt;= v1.7</text><text x="505" y="140" transform="scale(.1)" fill="#fff" textLength="470">&gt;= v1.7</text></g></svg>

After

Width:  |  Height:  |  Size: 1.1 KiB

View File

@ -1 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" width="114" height="20" role="img" aria-label="os: Linux | macOS"><title>os: Linux | macOS</title><linearGradient id="s" x2="0" y2="100%"><stop offset="0" stop-color="#bbb" stop-opacity=".1"/><stop offset="1" stop-opacity=".1"/></linearGradient><clipPath id="r"><rect width="114" height="20" rx="3" fill="#fff"/></clipPath><g clip-path="url(#r)"><rect width="23" height="20" fill="#555"/><rect x="23" width="91" height="20" fill="#ff69b4"/><rect width="114" height="20" fill="url(#s)"/></g><g fill="#fff" text-anchor="middle" font-family="Verdana,Geneva,DejaVu Sans,sans-serif" text-rendering="geometricPrecision" font-size="110"><text aria-hidden="true" x="125" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="130">os</text><text x="125" y="140" transform="scale(.1)" fill="#fff" textLength="130">os</text><text aria-hidden="true" x="675" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="810">Linux | macOS</text><text x="675" y="140" transform="scale(.1)" fill="#fff" textLength="810">Linux | macOS</text></g></svg> <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" width="114" height="20" role="img" aria-label="os: Linux | macOS"><title>os: Linux | macOS</title><linearGradient id="s" x2="0" y2="100%"><stop offset="0" stop-color="#bbb" stop-opacity=".1"/><stop offset="1" stop-opacity=".1"/></linearGradient><clipPath id="r"><rect width="114" height="20" rx="3" fill="#fff"/></clipPath><g clip-path="url(#r)"><rect width="23" height="20" fill="#555"/><rect x="23" width="91" height="20" fill="#ff69b4"/><rect width="114" height="20" fill="url(#s)"/></g><g fill="#fff" text-anchor="middle" font-family="Verdana,Geneva,DejaVu Sans,sans-serif" text-rendering="geometricPrecision" font-size="110"><text aria-hidden="true" x="125" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="130">os</text><text 
x="125" y="140" transform="scale(.1)" fill="#fff" textLength="130">os</text><text aria-hidden="true" x="675" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="810">Linux | macOS</text><text x="675" y="140" transform="scale(.1)" fill="#fff" textLength="810">Linux | macOS</text></g></svg>

Before

Width:  |  Height:  |  Size: 1.1 KiB

After

Width:  |  Height:  |  Size: 1.1 KiB

View File

@ -1 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" width="170" height="20" role="img" aria-label="python: 3.6 | 3.7 | 3.8 | 3.9"><title>python: 3.6 | 3.7 | 3.8 | 3.9</title><linearGradient id="s" x2="0" y2="100%"><stop offset="0" stop-color="#bbb" stop-opacity=".1"/><stop offset="1" stop-opacity=".1"/></linearGradient><clipPath id="r"><rect width="170" height="20" rx="3" fill="#fff"/></clipPath><g clip-path="url(#r)"><rect width="49" height="20" fill="#555"/><rect x="49" width="121" height="20" fill="#007ec6"/><rect width="170" height="20" fill="url(#s)"/></g><g fill="#fff" text-anchor="middle" font-family="Verdana,Geneva,DejaVu Sans,sans-serif" text-rendering="geometricPrecision" font-size="110"><text aria-hidden="true" x="255" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="390">python</text><text x="255" y="140" transform="scale(.1)" fill="#fff" textLength="390">python</text><text aria-hidden="true" x="1085" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="1110">3.6 | 3.7 | 3.8 | 3.9</text><text x="1085" y="140" transform="scale(.1)" fill="#fff" textLength="1110">3.6 | 3.7 | 3.8 | 3.9</text></g></svg> <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" width="170" height="20" role="img" aria-label="python: 3.6 | 3.7 | 3.8 | 3.9"><title>python: 3.6 | 3.7 | 3.8 | 3.9</title><linearGradient id="s" x2="0" y2="100%"><stop offset="0" stop-color="#bbb" stop-opacity=".1"/><stop offset="1" stop-opacity=".1"/></linearGradient><clipPath id="r"><rect width="170" height="20" rx="3" fill="#fff"/></clipPath><g clip-path="url(#r)"><rect width="49" height="20" fill="#555"/><rect x="49" width="121" height="20" fill="#007ec6"/><rect width="170" height="20" fill="url(#s)"/></g><g fill="#fff" text-anchor="middle" font-family="Verdana,Geneva,DejaVu Sans,sans-serif" text-rendering="geometricPrecision" font-size="110"><text aria-hidden="true" x="255" y="150" 
fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="390">python</text><text x="255" y="140" transform="scale(.1)" fill="#fff" textLength="390">python</text><text aria-hidden="true" x="1085" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="1110">3.6 | 3.7 | 3.8 | 3.9</text><text x="1085" y="140" transform="scale(.1)" fill="#fff" textLength="1110">3.6 | 3.7 | 3.8 | 3.9</text></g></svg>

Before

Width:  |  Height:  |  Size: 1.2 KiB

After

Width:  |  Height:  |  Size: 1.2 KiB

View File

@ -1 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" width="286" height="20" role="img" aria-label="torch: 1.6.0 | 1.7.0 | 1.7.1 | 1.8.0 | 1.8.1 | 1.9.0"><title>torch: 1.6.0 | 1.7.0 | 1.7.1 | 1.8.0 | 1.8.1 | 1.9.0</title><linearGradient id="s" x2="0" y2="100%"><stop offset="0" stop-color="#bbb" stop-opacity=".1"/><stop offset="1" stop-opacity=".1"/></linearGradient><clipPath id="r"><rect width="286" height="20" rx="3" fill="#fff"/></clipPath><g clip-path="url(#r)"><rect width="39" height="20" fill="#555"/><rect x="39" width="247" height="20" fill="#97ca00"/><rect width="286" height="20" fill="url(#s)"/></g><g fill="#fff" text-anchor="middle" font-family="Verdana,Geneva,DejaVu Sans,sans-serif" text-rendering="geometricPrecision" font-size="110"><text aria-hidden="true" x="205" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="290">torch</text><text x="205" y="140" transform="scale(.1)" fill="#fff" textLength="290">torch</text><text aria-hidden="true" x="1615" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="2370">1.6.0 | 1.7.0 | 1.7.1 | 1.8.0 | 1.8.1 | 1.9.0</text><text x="1615" y="140" transform="scale(.1)" fill="#fff" textLength="2370">1.6.0 | 1.7.0 | 1.7.1 | 1.8.0 | 1.8.1 | 1.9.0</text></g></svg> <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" width="286" height="20" role="img" aria-label="torch: 1.6.0 | 1.7.0 | 1.7.1 | 1.8.0 | 1.8.1 | 1.9.0"><title>torch: 1.6.0 | 1.7.0 | 1.7.1 | 1.8.0 | 1.8.1 | 1.9.0</title><linearGradient id="s" x2="0" y2="100%"><stop offset="0" stop-color="#bbb" stop-opacity=".1"/><stop offset="1" stop-opacity=".1"/></linearGradient><clipPath id="r"><rect width="286" height="20" rx="3" fill="#fff"/></clipPath><g clip-path="url(#r)"><rect width="39" height="20" fill="#555"/><rect x="39" width="247" height="20" fill="#97ca00"/><rect width="286" height="20" fill="url(#s)"/></g><g fill="#fff" text-anchor="middle" 
font-family="Verdana,Geneva,DejaVu Sans,sans-serif" text-rendering="geometricPrecision" font-size="110"><text aria-hidden="true" x="205" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="290">torch</text><text x="205" y="140" transform="scale(.1)" fill="#fff" textLength="290">torch</text><text aria-hidden="true" x="1615" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="2370">1.6.0 | 1.7.0 | 1.7.1 | 1.8.0 | 1.8.1 | 1.9.0</text><text x="1615" y="140" transform="scale(.1)" fill="#fff" textLength="2370">1.6.0 | 1.7.0 | 1.7.1 | 1.8.0 | 1.8.1 | 1.9.0</text></g></svg>

Before

Width:  |  Height:  |  Size: 1.3 KiB

After

Width:  |  Height:  |  Size: 1.3 KiB

View File

@ -7,6 +7,7 @@ Installation
- |device| - |device|
- |python_versions| - |python_versions|
- |torch_versions| - |torch_versions|
- |k2_versions|
.. |os| image:: ./images/os-Linux_macOS-ff69b4.svg .. |os| image:: ./images/os-Linux_macOS-ff69b4.svg
:alt: Supported operating systems :alt: Supported operating systems
@ -20,7 +21,10 @@ Installation
.. |torch_versions| image:: ./images/torch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg .. |torch_versions| image:: ./images/torch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg
:alt: Supported PyTorch versions :alt: Supported PyTorch versions
icefall depends on `k2 <https://github.com/k2-fsa/k2>`_ and .. |k2_versions| image:: ./images/k2-v-1.7.svg
:alt: Supported k2 versions
``icefall`` depends on `k2 <https://github.com/k2-fsa/k2>`_ and
`lhotse <https://github.com/lhotse-speech/lhotse>`_. `lhotse <https://github.com/lhotse-speech/lhotse>`_.
We recommend that you install ``k2`` first, as ``k2`` is bound to We recommend that you install ``k2`` first, as ``k2`` is bound to
@ -32,12 +36,16 @@ installs its dependency PyTorch, which can be reused by ``lhotse``.
-------------- --------------
Please refer to `<https://k2.readthedocs.io/en/latest/installation/index.html>`_ Please refer to `<https://k2.readthedocs.io/en/latest/installation/index.html>`_
to install `k2`. to install ``k2``.
.. CAUTION::
You need to install ``k2`` **v1.7** or later.
.. HINT:: .. HINT::
If you have already installed PyTorch and don't want to replace it, If you have already installed PyTorch and don't want to replace it,
please install a version of k2 that is compiled against the version please install a version of ``k2`` that is compiled against the version
of PyTorch you are using. of PyTorch you are using.
(2) Install lhotse (2) Install lhotse
@ -50,10 +58,15 @@ to install ``lhotse``.
Installing ``lhotse`` also installs its dependency `torchaudio <https://github.com/pytorch/audio>`_. Installing ``lhotse`` also installs its dependency `torchaudio <https://github.com/pytorch/audio>`_.
.. CAUTION::
If you have installed ``torchaudio``, please consider uninstalling it before
installing ``lhotse``. Otherwise, it may update your already installed PyTorch.
(3) Download icefall (3) Download icefall
-------------------- --------------------
icefall is a collection of Python scripts, so you don't need to install it ``icefall`` is a collection of Python scripts, so you don't need to install it
and we don't provide a ``setup.py`` to install it. and we don't provide a ``setup.py`` to install it.
What you need is to download it and set the environment variable ``PYTHONPATH`` What you need is to download it and set the environment variable ``PYTHONPATH``
@ -367,7 +380,7 @@ Now let us run the training part:
.. CAUTION:: .. CAUTION::
We use ``export CUDA_VISIBLE_DEVICES=""`` so that icefall uses CPU We use ``export CUDA_VISIBLE_DEVICES=""`` so that ``icefall`` uses CPU
even if there are GPUs available. even if there are GPUs available.
The training log is given below: The training log is given below:

View File

@ -15,4 +15,3 @@ We may add recipes for other tasks as well in the future.
yesno yesno
librispeech librispeech

View File

@ -209,7 +209,7 @@ After downloading, you will have the following files:
|-- 1221-135766-0001.flac |-- 1221-135766-0001.flac
|-- 1221-135766-0002.flac |-- 1221-135766-0002.flac
`-- trans.txt `-- trans.txt
6 directories, 10 files 6 directories, 10 files
@ -256,14 +256,14 @@ The output is:
2021-08-24 16:57:28,098 INFO [pretrained.py:266] 2021-08-24 16:57:28,098 INFO [pretrained.py:266]
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac: ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac:
AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac: ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac:
GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONORED BOSOM TO CONNECT HER PARENT FOREVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONORED BOSOM TO CONNECT HER PARENT FOREVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac: ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac:
YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
2021-08-24 16:57:28,099 INFO [pretrained.py:268] Decoding Done 2021-08-24 16:57:28,099 INFO [pretrained.py:268] Decoding Done
@ -297,14 +297,14 @@ The decoding output is:
2021-08-24 16:39:54,010 INFO [pretrained.py:266] 2021-08-24 16:39:54,010 INFO [pretrained.py:266]
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac: ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac:
AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac: ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac:
GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONORED BOSOM TO CONNECT HER PARENT FOREVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONORED BOSOM TO CONNECT HER PARENT FOREVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac: ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac:
YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
2021-08-24 16:39:54,010 INFO [pretrained.py:268] Decoding Done 2021-08-24 16:39:54,010 INFO [pretrained.py:268] Decoding Done

View File

@ -43,4 +43,3 @@ We searched the lm_score_scale for best results, the scales that produced the WE
|--|--| |--|--|
|test-clean|0.8| |test-clean|0.8|
|test-other|0.9| |test-other|0.9|

View File

@ -45,6 +45,7 @@ from icefall.utils import (
get_texts, get_texts,
setup_logger, setup_logger,
store_transcripts, store_transcripts,
str2bool,
write_error_stats, write_error_stats,
) )
@ -116,6 +117,17 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--export",
type=str2bool,
default=False,
help="""When enabled, the averaged model is saved to
conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved.
pretrained.pt contains a dict {"model": model.state_dict()},
which can be loaded by `icefall.checkpoint.load_checkpoint()`.
""",
)
return parser return parser
@ -541,6 +553,13 @@ def main():
logging.info(f"averaging {filenames}") logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames)) model.load_state_dict(average_checkpoints(filenames))
if params.export:
logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
torch.save(
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
)
return
model.to(device) model.to(device)
model.eval() model.eval()
num_param = sum([p.numel() for p in model.parameters()]) num_param = sum([p.numel() for p in model.parameters()])

View File

@ -16,9 +16,8 @@
# limitations under the License. # limitations under the License.
from subsampling import Conv2dSubsampling
from subsampling import VggSubsampling
import torch import torch
from subsampling import Conv2dSubsampling, VggSubsampling
def test_conv2d_subsampling(): def test_conv2d_subsampling():

View File

@ -17,17 +17,16 @@
import torch import torch
from torch.nn.utils.rnn import pad_sequence
from transformer import ( from transformer import (
Transformer, Transformer,
add_eos,
add_sos,
decoder_padding_mask,
encoder_padding_mask, encoder_padding_mask,
generate_square_subsequent_mask, generate_square_subsequent_mask,
decoder_padding_mask,
add_sos,
add_eos,
) )
from torch.nn.utils.rnn import pad_sequence
def test_encoder_padding_mask(): def test_encoder_padding_mask():
supervisions = { supervisions = {

View File

@ -102,14 +102,14 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
LG.labels[LG.labels >= first_token_disambig_id] = 0 LG.labels[LG.labels >= first_token_disambig_id] = 0
assert isinstance(LG.aux_labels, k2.RaggedInt) assert isinstance(LG.aux_labels, k2.RaggedTensor)
LG.aux_labels.values()[LG.aux_labels.values() >= first_word_disambig_id] = 0 LG.aux_labels.data[LG.aux_labels.data >= first_word_disambig_id] = 0
LG = k2.remove_epsilon(LG) LG = k2.remove_epsilon(LG)
logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}") logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
LG = k2.connect(LG) LG = k2.connect(LG)
LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0) LG.aux_labels = LG.aux_labels.remove_values_eq(0)
logging.info("Arc sorting LG") logging.info("Arc sorting LG")
LG = k2.arc_sort(LG) LG = k2.arc_sort(LG)

View File

@ -99,8 +99,10 @@ def get_params() -> AttributeDict:
# - nbest-rescoring # - nbest-rescoring
# - whole-lattice-rescoring # - whole-lattice-rescoring
"method": "whole-lattice-rescoring", "method": "whole-lattice-rescoring",
# "method": "1best",
# "method": "nbest",
# num_paths is used when method is "nbest" and "nbest-rescoring" # num_paths is used when method is "nbest" and "nbest-rescoring"
"num_paths": 30, "num_paths": 100,
} }
) )
return params return params
@ -424,6 +426,7 @@ def main():
torch.save( torch.save(
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
) )
return
model.to(device) model.to(device)
model.eval() model.eval()

0
egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py Normal file → Executable file
View File

View File

@ -80,14 +80,14 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
LG.labels[LG.labels >= first_token_disambig_id] = 0 LG.labels[LG.labels >= first_token_disambig_id] = 0
assert isinstance(LG.aux_labels, k2.RaggedInt) assert isinstance(LG.aux_labels, k2.RaggedTensor)
LG.aux_labels.values()[LG.aux_labels.values() >= first_word_disambig_id] = 0 LG.aux_labels.data[LG.aux_labels.data >= first_word_disambig_id] = 0
LG = k2.remove_epsilon(LG) LG = k2.remove_epsilon(LG)
logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}") logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
LG = k2.connect(LG) LG = k2.connect(LG)
LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0) LG.aux_labels = LG.aux_labels.remove_values_eq(0)
logging.info("Arc sorting LG") logging.info("Arc sorting LG")
LG = k2.arc_sort(LG) LG = k2.arc_sort(LG)

View File

@ -296,6 +296,7 @@ def main():
torch.save( torch.save(
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
) )
return
model.to(device) model.to(device)
model.eval() model.eval()

View File

@ -84,8 +84,8 @@ def _intersect_device(
for start, end in splits: for start, end in splits:
indexes = torch.arange(start, end).to(b_to_a_map) indexes = torch.arange(start, end).to(b_to_a_map)
fsas = k2.index(b_fsas, indexes) fsas = k2.index_fsa(b_fsas, indexes)
b_to_a = k2.index(b_to_a_map, indexes) b_to_a = k2.index_select(b_to_a_map, indexes)
path_lattice = k2.intersect_device( path_lattice = k2.intersect_device(
a_fsas, fsas, b_to_a_map=b_to_a, sorted_match_a=sorted_match_a a_fsas, fsas, b_to_a_map=b_to_a, sorted_match_a=sorted_match_a
) )
@ -215,18 +215,16 @@ def nbest_decoding(
scale=scale, scale=scale,
) )
# word_seq is a k2.RaggedInt sharing the same shape as `path` # word_seq is a k2.RaggedTensor sharing the same shape as `path`
# but it contains word IDs. Note that it also contains 0s and -1s. # but it contains word IDs. Note that it also contains 0s and -1s.
# The last entry in each sublist is -1. # The last entry in each sublist is -1.
word_seq = k2.index(lattice.aux_labels, path) if isinstance(lattice.aux_labels, torch.Tensor):
# Note: the above operation supports also the case when word_seq = k2.ragged.index(lattice.aux_labels, path)
# lattice.aux_labels is a ragged tensor. In that case, else:
# `remove_axis=True` is used inside the pybind11 binding code, word_seq = lattice.aux_labels.index(path, remove_axis=True)
# so the resulting `word_seq` still has 3 axes, like `path`.
# The 3 axes are [seq][path][word_id]
# Remove 0 (epsilon) and -1 from word_seq # Remove 0 (epsilon) and -1 from word_seq
word_seq = k2.ragged.remove_values_leq(word_seq, 0) word_seq = word_seq.remove_values_leq(0)
# Remove sequences with identical word sequences. # Remove sequences with identical word sequences.
# #
@ -234,12 +232,12 @@ def nbest_decoding(
# `new2old` is a 1-D torch.Tensor mapping from the output path index # `new2old` is a 1-D torch.Tensor mapping from the output path index
# to the input path index. # to the input path index.
# new2old.numel() == unique_word_seqs.tot_size(1) # new2old.numel() == unique_word_seqs.tot_size(1)
unique_word_seq, _, new2old = k2.ragged.unique_sequences( unique_word_seq, _, new2old = word_seq.unique(
word_seq, need_num_repeats=False, need_new2old_indexes=True need_num_repeats=False, need_new2old_indexes=True
) )
# Note: unique_word_seq still has the same axes as word_seq # Note: unique_word_seq still has the same axes as word_seq
seq_to_path_shape = k2.ragged.get_layer(unique_word_seq.shape(), 0) seq_to_path_shape = unique_word_seq.shape.get_layer(0)
# path_to_seq_map is a 1-D torch.Tensor. # path_to_seq_map is a 1-D torch.Tensor.
# path_to_seq_map[i] is the seq to which the i-th path belongs # path_to_seq_map[i] is the seq to which the i-th path belongs
@ -247,7 +245,7 @@ def nbest_decoding(
# Remove the seq axis. # Remove the seq axis.
# Now unique_word_seq has only two axes [path][word] # Now unique_word_seq has only two axes [path][word]
unique_word_seq = k2.ragged.remove_axis(unique_word_seq, 0) unique_word_seq = unique_word_seq.remove_axis(0)
# word_fsa is an FsaVec with axes [path][state][arc] # word_fsa is an FsaVec with axes [path][state][arc]
word_fsa = k2.linear_fsa(unique_word_seq) word_fsa = k2.linear_fsa(unique_word_seq)
@ -275,35 +273,35 @@ def nbest_decoding(
use_double_scores=use_double_scores, log_semiring=False use_double_scores=use_double_scores, log_semiring=False
) )
# RaggedFloat currently supports float32 only. ragged_tot_scores = k2.RaggedTensor(seq_to_path_shape, tot_scores)
# If Ragged<double> is wrapped, we can use k2.RaggedDouble here
ragged_tot_scores = k2.RaggedFloat(
seq_to_path_shape, tot_scores.to(torch.float32)
)
argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores) argmax_indexes = ragged_tot_scores.argmax()
# Since we invoked `k2.ragged.unique_sequences`, which reorders # Since we invoked `k2.RaggedTensor.unique`, which reorders
# the index from `path`, we use `new2old` here to convert argmax_indexes # the index from `path`, we use `new2old` here to convert argmax_indexes
# to the indexes into `path`. # to the indexes into `path`.
# #
# Use k2.index here since argmax_indexes' dtype is torch.int32 # Use k2.index_select here since argmax_indexes' dtype is torch.int32
best_path_indexes = k2.index(new2old, argmax_indexes) best_path_indexes = k2.index_select(new2old, argmax_indexes)
path_2axes = k2.ragged.remove_axis(path, 0) path_2axes = path.remove_axis(0)
# best_path is a k2.RaggedInt with 2 axes [path][arc_pos] # best_path is a k2.RaggedTensor with 2 axes [path][arc_pos]
best_path = k2.index(path_2axes, best_path_indexes) best_path, _ = path_2axes.index(
indexes=best_path_indexes, axis=0, need_value_indexes=False
)
# labels is a k2.RaggedInt with 2 axes [path][token_id] # labels is a k2.RaggedTensor with 2 axes [path][token_id]
# Note that it contains -1s. # Note that it contains -1s.
labels = k2.index(lattice.labels.contiguous(), best_path) labels = k2.ragged.index(lattice.labels.contiguous(), best_path)
labels = k2.ragged.remove_values_eq(labels, -1) labels = labels.remove_values_eq(-1)
# lattice.aux_labels is a k2.RaggedInt tensor with 2 axes, so # lattice.aux_labels is a k2.RaggedTensor with 2 axes, so
# aux_labels is also a k2.RaggedInt with 2 axes # aux_labels is also a k2.RaggedTensor with 2 axes
aux_labels = k2.index(lattice.aux_labels, best_path.values()) aux_labels, _ = lattice.aux_labels.index(
indexes=best_path.data, axis=0, need_value_indexes=False
)
best_path_fsa = k2.linear_fsa(labels) best_path_fsa = k2.linear_fsa(labels)
best_path_fsa.aux_labels = aux_labels best_path_fsa.aux_labels = aux_labels
@ -426,33 +424,36 @@ def rescore_with_n_best_list(
scale=scale, scale=scale,
) )
# word_seq is a k2.RaggedInt sharing the same shape as `path` # word_seq is a k2.RaggedTensor sharing the same shape as `path`
# but it contains word IDs. Note that it also contains 0s and -1s. # but it contains word IDs. Note that it also contains 0s and -1s.
# The last entry in each sublist is -1. # The last entry in each sublist is -1.
word_seq = k2.index(lattice.aux_labels, path) if isinstance(lattice.aux_labels, torch.Tensor):
word_seq = k2.ragged.index(lattice.aux_labels, path)
else:
word_seq = lattice.aux_labels.index(path, remove_axis=True)
# Remove epsilons and -1 from word_seq # Remove epsilons and -1 from word_seq
word_seq = k2.ragged.remove_values_leq(word_seq, 0) word_seq = word_seq.remove_values_leq(0)
# Remove paths that have identical word sequences. # Remove paths that have identical word sequences.
# #
# unique_word_seq is still a k2.RaggedInt with 3 axes [seq][path][word] # unique_word_seq is still a k2.RaggedTensor with 3 axes [seq][path][word]
# except that there are no repeated paths with the same word_seq # except that there are no repeated paths with the same word_seq
# within a sequence. # within a sequence.
# #
# num_repeats is also a k2.RaggedInt with 2 axes containing the # num_repeats is also a k2.RaggedTensor with 2 axes containing the
# multiplicities of each path. # multiplicities of each path.
# num_repeats.num_elements() == unique_word_seqs.tot_size(1) # num_repeats.numel() == unique_word_seqs.tot_size(1)
# #
# Since k2.ragged.unique_sequences will reorder paths within a seq, # Since k2.RaggedTensor.unique will reorder paths within a seq,
# `new2old` is a 1-D torch.Tensor mapping from the output path index # `new2old` is a 1-D torch.Tensor mapping from the output path index
# to the input path index. # to the input path index.
# new2old.numel() == unique_word_seqs.tot_size(1) # new2old.numel() == unique_word_seqs.tot_size(1)
unique_word_seq, num_repeats, new2old = k2.ragged.unique_sequences( unique_word_seq, num_repeats, new2old = word_seq.unique(
word_seq, need_num_repeats=True, need_new2old_indexes=True need_num_repeats=True, need_new2old_indexes=True
) )
seq_to_path_shape = k2.ragged.get_layer(unique_word_seq.shape(), 0) seq_to_path_shape = unique_word_seq.shape.get_layer(0)
# path_to_seq_map is a 1-D torch.Tensor. # path_to_seq_map is a 1-D torch.Tensor.
# path_to_seq_map[i] is the seq to which the i-th path # path_to_seq_map[i] is the seq to which the i-th path
@ -461,7 +462,7 @@ def rescore_with_n_best_list(
# Remove the seq axis. # Remove the seq axis.
# Now unique_word_seq has only two axes [path][word] # Now unique_word_seq has only two axes [path][word]
unique_word_seq = k2.ragged.remove_axis(unique_word_seq, 0) unique_word_seq = unique_word_seq.remove_axis(0)
# word_fsa is an FsaVec with axes [path][state][arc] # word_fsa is an FsaVec with axes [path][state][arc]
word_fsa = k2.linear_fsa(unique_word_seq) word_fsa = k2.linear_fsa(unique_word_seq)
@ -485,39 +486,42 @@ def rescore_with_n_best_list(
use_double_scores=True, log_semiring=False use_double_scores=True, log_semiring=False
) )
path_2axes = k2.ragged.remove_axis(path, 0) path_2axes = path.remove_axis(0)
ans = dict() ans = dict()
for lm_scale in lm_scale_list: for lm_scale in lm_scale_list:
tot_scores = am_scores / lm_scale + lm_scores tot_scores = am_scores / lm_scale + lm_scores
# Remember that we used `k2.ragged.unique_sequences` to remove repeated # Remember that we used `k2.RaggedTensor.unique` to remove repeated
# paths to avoid redundant computation in `k2.intersect_device`. # paths to avoid redundant computation in `k2.intersect_device`.
# Now we use `num_repeats` to correct the scores for each path. # Now we use `num_repeats` to correct the scores for each path.
# #
# NOTE(fangjun): It is commented out as it leads to a worse WER # NOTE(fangjun): It is commented out as it leads to a worse WER
# tot_scores = tot_scores * num_repeats.values() # tot_scores = tot_scores * num_repeats.values()
ragged_tot_scores = k2.RaggedFloat( ragged_tot_scores = k2.RaggedTensor(seq_to_path_shape, tot_scores)
seq_to_path_shape, tot_scores.to(torch.float32) argmax_indexes = ragged_tot_scores.argmax()
)
argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)
# Use k2.index here since argmax_indexes' dtype is torch.int32 # Use k2.index here since argmax_indexes' dtype is torch.int32
best_path_indexes = k2.index(new2old, argmax_indexes) best_path_indexes = k2.index_select(new2old, argmax_indexes)
# best_path is a k2.RaggedInt with 2 axes [path][arc_pos] # best_path is a k2.RaggedInt with 2 axes [path][arc_pos]
best_path = k2.index(path_2axes, best_path_indexes) best_path, _ = path_2axes.index(
indexes=best_path_indexes, axis=0, need_value_indexes=False
)
# labels is a k2.RaggedInt with 2 axes [path][phone_id] # labels is a k2.RaggedTensor with 2 axes [path][phone_id]
# Note that it contains -1s. # Note that it contains -1s.
labels = k2.index(lattice.labels.contiguous(), best_path) labels = k2.ragged.index(lattice.labels.contiguous(), best_path)
labels = k2.ragged.remove_values_eq(labels, -1) labels = labels.remove_values_eq(-1)
# lattice.aux_labels is a k2.RaggedInt tensor with 2 axes, so # lattice.aux_labels is a k2.RaggedTensor tensor with 2 axes, so
# aux_labels is also a k2.RaggedInt with 2 axes # aux_labels is also a k2.RaggedTensor with 2 axes
aux_labels = k2.index(lattice.aux_labels, best_path.values())
aux_labels, _ = lattice.aux_labels.index(
indexes=best_path.data, axis=0, need_value_indexes=False
)
best_path_fsa = k2.linear_fsa(labels) best_path_fsa = k2.linear_fsa(labels)
best_path_fsa.aux_labels = aux_labels best_path_fsa.aux_labels = aux_labels
@ -659,12 +663,16 @@ def nbest_oracle(
scale=scale, scale=scale,
) )
word_seq = k2.index(lattice.aux_labels, path) if isinstance(lattice.aux_labels, torch.Tensor):
word_seq = k2.ragged.remove_values_leq(word_seq, 0) word_seq = k2.ragged.index(lattice.aux_labels, path)
unique_word_seq, _, _ = k2.ragged.unique_sequences( else:
word_seq, need_num_repeats=False, need_new2old_indexes=False word_seq = lattice.aux_labels.index(path, remove_axis=True)
word_seq = word_seq.remove_values_leq(0)
unique_word_seq, _, _ = word_seq.unique(
need_num_repeats=False, need_new2old_indexes=False
) )
unique_word_ids = k2.ragged.to_list(unique_word_seq) unique_word_ids = unique_word_seq.tolist()
assert len(unique_word_ids) == len(ref_texts) assert len(unique_word_ids) == len(ref_texts)
# unique_word_ids[i] contains all hypotheses of the i-th utterance # unique_word_ids[i] contains all hypotheses of the i-th utterance
@ -743,33 +751,36 @@ def rescore_with_attention_decoder(
scale=scale, scale=scale,
) )
# word_seq is a k2.RaggedInt sharing the same shape as `path` # word_seq is a k2.RaggedTensor sharing the same shape as `path`
# but it contains word IDs. Note that it also contains 0s and -1s. # but it contains word IDs. Note that it also contains 0s and -1s.
# The last entry in each sublist is -1. # The last entry in each sublist is -1.
word_seq = k2.index(lattice.aux_labels, path) if isinstance(lattice.aux_labels, torch.Tensor):
word_seq = k2.ragged.index(lattice.aux_labels, path)
else:
word_seq = lattice.aux_labels.index(path, remove_axis=True)
# Remove epsilons and -1 from word_seq # Remove epsilons and -1 from word_seq
word_seq = k2.ragged.remove_values_leq(word_seq, 0) word_seq = word_seq.remove_values_leq(0)
# Remove paths that have identical word sequences. # Remove paths that have identical word sequences.
# #
# unique_word_seq is still a k2.RaggedInt with 3 axes [seq][path][word] # unique_word_seq is still a k2.RaggedTensor with 3 axes [seq][path][word]
# except that there are no repeated paths with the same word_seq # except that there are no repeated paths with the same word_seq
# within a sequence. # within a sequence.
# #
# num_repeats is also a k2.RaggedInt with 2 axes containing the # num_repeats is also a k2.RaggedTensor with 2 axes containing the
# multiplicities of each path. # multiplicities of each path.
# num_repeats.num_elements() == unique_word_seqs.tot_size(1) # num_repeats.numel() == unique_word_seqs.tot_size(1)
# #
# Since k2.ragged.unique_sequences will reorder paths within a seq, # Since k2.ragged.unique_sequences will reorder paths within a seq,
# `new2old` is a 1-D torch.Tensor mapping from the output path index # `new2old` is a 1-D torch.Tensor mapping from the output path index
# to the input path index. # to the input path index.
# new2old.numel() == unique_word_seq.tot_size(1) # new2old.numel() == unique_word_seq.tot_size(1)
unique_word_seq, num_repeats, new2old = k2.ragged.unique_sequences( unique_word_seq, num_repeats, new2old = word_seq.unique(
word_seq, need_num_repeats=True, need_new2old_indexes=True need_num_repeats=True, need_new2old_indexes=True
) )
seq_to_path_shape = k2.ragged.get_layer(unique_word_seq.shape(), 0) seq_to_path_shape = unique_word_seq.shape.get_layer(0)
# path_to_seq_map is a 1-D torch.Tensor. # path_to_seq_map is a 1-D torch.Tensor.
# path_to_seq_map[i] is the seq to which the i-th path # path_to_seq_map[i] is the seq to which the i-th path
@ -778,7 +789,7 @@ def rescore_with_attention_decoder(
# Remove the seq axis. # Remove the seq axis.
# Now unique_word_seq has only two axes [path][word] # Now unique_word_seq has only two axes [path][word]
unique_word_seq = k2.ragged.remove_axis(unique_word_seq, 0) unique_word_seq = unique_word_seq.remove_axis(0)
# word_fsa is an FsaVec with axes [path][state][arc] # word_fsa is an FsaVec with axes [path][state][arc]
word_fsa = k2.linear_fsa(unique_word_seq) word_fsa = k2.linear_fsa(unique_word_seq)
@ -796,20 +807,23 @@ def rescore_with_attention_decoder(
# CAUTION: The "tokens" attribute is set in the file # CAUTION: The "tokens" attribute is set in the file
# local/compile_hlg.py # local/compile_hlg.py
token_seq = k2.index(lattice.tokens, path) if isinstance(lattice.tokens, torch.Tensor):
token_seq = k2.ragged.index(lattice.tokens, path)
else:
token_seq = lattice.tokens.index(path, remove_axis=True)
# Remove epsilons and -1 from token_seq # Remove epsilons and -1 from token_seq
token_seq = k2.ragged.remove_values_leq(token_seq, 0) token_seq = token_seq.remove_values_leq(0)
# Remove the seq axis. # Remove the seq axis.
token_seq = k2.ragged.remove_axis(token_seq, 0) token_seq = token_seq.remove_axis(0)
token_seq, _ = k2.ragged.index( token_seq, _ = token_seq.index(
token_seq, indexes=new2old, axis=0, need_value_indexes=False indexes=new2old, axis=0, need_value_indexes=False
) )
# Now each path in unique_word_seq has its corresponding token IDs. # Now each path in unique_word_seq has its corresponding token IDs.
token_ids = k2.ragged.to_list(token_seq) token_ids = token_seq.tolist()
num_word_seqs = new2old.numel() num_word_seqs = new2old.numel()
@ -849,7 +863,7 @@ def rescore_with_attention_decoder(
else: else:
attention_scale_list = [attention_scale] attention_scale_list = [attention_scale]
path_2axes = k2.ragged.remove_axis(path, 0) path_2axes = path.remove_axis(0)
ans = dict() ans = dict()
for n_scale in ngram_lm_scale_list: for n_scale in ngram_lm_scale_list:
@ -859,23 +873,28 @@ def rescore_with_attention_decoder(
+ n_scale * ngram_lm_scores + n_scale * ngram_lm_scores
+ a_scale * attention_scores + a_scale * attention_scores
) )
ragged_tot_scores = k2.RaggedFloat(seq_to_path_shape, tot_scores) ragged_tot_scores = k2.RaggedTensor(seq_to_path_shape, tot_scores)
argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores) argmax_indexes = ragged_tot_scores.argmax()
best_path_indexes = k2.index(new2old, argmax_indexes) best_path_indexes = k2.index_select(new2old, argmax_indexes)
# best_path is a k2.RaggedInt with 2 axes [path][arc_pos] # best_path is a k2.RaggedInt with 2 axes [path][arc_pos]
best_path = k2.index(path_2axes, best_path_indexes) best_path, _ = path_2axes.index(
indexes=best_path_indexes, axis=0, need_value_indexes=False
)
# labels is a k2.RaggedInt with 2 axes [path][token_id] # labels is a k2.RaggedTensor with 2 axes [path][token_id]
# Note that it contains -1s. # Note that it contains -1s.
labels = k2.index(lattice.labels.contiguous(), best_path) labels = k2.ragged.index(lattice.labels.contiguous(), best_path)
labels = k2.ragged.remove_values_eq(labels, -1) labels = labels.remove_values_eq(-1)
# lattice.aux_labels is a k2.RaggedInt tensor with 2 axes, so if isinstance(lattice.aux_labels, torch.Tensor):
# aux_labels is also a k2.RaggedInt with 2 axes aux_labels = k2.index_select(lattice.aux_labels, best_path.data)
aux_labels = k2.index(lattice.aux_labels, best_path.values()) else:
aux_labels, _ = lattice.aux_labels.index(
indexes=best_path.data, axis=0, need_value_indexes=False
)
best_path_fsa = k2.linear_fsa(labels) best_path_fsa = k2.linear_fsa(labels)
best_path_fsa.aux_labels = aux_labels best_path_fsa.aux_labels = aux_labels

View File

@ -157,7 +157,7 @@ class BpeLexicon(Lexicon):
lang_dir / "lexicon.txt" lang_dir / "lexicon.txt"
) )
def convert_lexicon_to_ragged(self, filename: str) -> k2.RaggedInt: def convert_lexicon_to_ragged(self, filename: str) -> k2.RaggedTensor:
"""Read a BPE lexicon from file and convert it to a """Read a BPE lexicon from file and convert it to a
k2 ragged tensor. k2 ragged tensor.
@ -200,19 +200,18 @@ class BpeLexicon(Lexicon):
) )
values = torch.tensor(token_ids, dtype=torch.int32) values = torch.tensor(token_ids, dtype=torch.int32)
return k2.RaggedInt(shape, values) return k2.RaggedTensor(shape, values)
def words_to_piece_ids(self, words: List[str]) -> k2.RaggedInt: def words_to_piece_ids(self, words: List[str]) -> k2.RaggedTensor:
"""Convert a list of words to a ragged tensor containing """Convert a list of words to a ragged tensor containing
word piece IDs. word piece IDs.
""" """
word_ids = [self.word_table[w] for w in words] word_ids = [self.word_table[w] for w in words]
word_ids = torch.tensor(word_ids, dtype=torch.int32) word_ids = torch.tensor(word_ids, dtype=torch.int32)
ragged, _ = k2.ragged.index( ragged, _ = self.ragged_lexicon.index(
self.ragged_lexicon,
indexes=word_ids, indexes=word_ids,
need_value_indexes=False,
axis=0, axis=0,
need_value_indexes=False,
) )
return ragged return ragged

View File

@ -26,7 +26,6 @@ from pathlib import Path
from typing import Dict, Iterable, List, TextIO, Tuple, Union from typing import Dict, Iterable, List, TextIO, Tuple, Union
import k2 import k2
import k2.ragged as k2r
import kaldialign import kaldialign
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@ -199,26 +198,25 @@ def get_texts(best_paths: k2.Fsa) -> List[List[int]]:
Returns a list of lists of int, containing the label sequences we Returns a list of lists of int, containing the label sequences we
decoded. decoded.
""" """
if isinstance(best_paths.aux_labels, k2.RaggedInt): if isinstance(best_paths.aux_labels, k2.RaggedTensor):
# remove 0's and -1's. # remove 0's and -1's.
aux_labels = k2r.remove_values_leq(best_paths.aux_labels, 0) aux_labels = best_paths.aux_labels.remove_values_leq(0)
aux_shape = k2r.compose_ragged_shapes( # TODO: change arcs.shape() to arcs.shape
best_paths.arcs.shape(), aux_labels.shape() aux_shape = best_paths.arcs.shape().compose(aux_labels.shape)
)
# remove the states and arcs axes. # remove the states and arcs axes.
aux_shape = k2r.remove_axis(aux_shape, 1) aux_shape = aux_shape.remove_axis(1)
aux_shape = k2r.remove_axis(aux_shape, 1) aux_shape = aux_shape.remove_axis(1)
aux_labels = k2.RaggedInt(aux_shape, aux_labels.values()) aux_labels = k2.RaggedTensor(aux_shape, aux_labels.data)
else: else:
# remove axis corresponding to states. # remove axis corresponding to states.
aux_shape = k2r.remove_axis(best_paths.arcs.shape(), 1) aux_shape = best_paths.arcs.shape().remove_axis(1)
aux_labels = k2.RaggedInt(aux_shape, best_paths.aux_labels) aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels)
# remove 0's and -1's. # remove 0's and -1's.
aux_labels = k2r.remove_values_leq(aux_labels, 0) aux_labels = aux_labels.remove_values_leq(0)
assert aux_labels.num_axes() == 2 assert aux_labels.num_axes == 2
return k2r.to_list(aux_labels) return aux_labels.tolist()
def store_transcripts( def store_transcripts(

View File

@ -16,9 +16,10 @@
# limitations under the License. # limitations under the License.
from pathlib import Path
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.lexicon import BpeLexicon from icefall.lexicon import BpeLexicon
from pathlib import Path
def test(): def test():

View File

@ -60,7 +60,7 @@ def test_get_texts_ragged():
4 4
""" """
) )
fsa1.aux_labels = k2.RaggedInt("[ [1 3 0 2] [] [4 0 1] [-1]]") fsa1.aux_labels = k2.RaggedTensor("[ [1 3 0 2] [] [4 0 1] [-1]]")
fsa2 = k2.Fsa.from_str( fsa2 = k2.Fsa.from_str(
""" """
@ -70,7 +70,7 @@ def test_get_texts_ragged():
3 3
""" """
) )
fsa2.aux_labels = k2.RaggedInt("[[3 0 5 0 8] [0 9 7 0] [-1]]") fsa2.aux_labels = k2.RaggedTensor("[[3 0 5 0 8] [0 9 7 0] [-1]]")
fsas = k2.Fsa.from_fsas([fsa1, fsa2]) fsas = k2.Fsa.from_fsas([fsa1, fsa2])
texts = get_texts(fsas) texts = get_texts(fsas)
assert texts == [[1, 3, 2, 4, 1], [3, 5, 8, 9, 7]] assert texts == [[1, 3, 2, 4, 1], [3, 5, 8, 9, 7]]