diff --git a/.github/workflows/run-yesno-recipe.yml b/.github/workflows/run-yesno-recipe.yml
index 39a6a0e80..876b95e71 100644
--- a/.github/workflows/run-yesno-recipe.yml
+++ b/.github/workflows/run-yesno-recipe.yml
@@ -21,11 +21,11 @@ on:
branches:
- master
pull_request:
- branches:
- - master
+ types: [labeled]
jobs:
run-yesno-recipe:
+ if: github.event.label.name == 'ready' || github.event_name == 'push'
runs-on: ${{ matrix.os }}
strategy:
matrix:
@@ -33,6 +33,8 @@ jobs:
# TODO: enable macOS for CPU testing
os: [ubuntu-18.04]
python-version: [3.8]
+ torch: ["1.8.1"]
+ k2-version: ["1.9.dev20210919"]
fail-fast: false
steps:
@@ -54,10 +56,8 @@ jobs:
- name: Install Python dependencies
run: |
- python3 -m pip install --upgrade pip black flake8
python3 -m pip install -U pip
- python3 -m pip install k2==1.4.dev20210822+cpu.torch1.7.1 -f https://k2-fsa.org/nightly/
- python3 -m pip install torchaudio==0.7.2
+ python3 -m pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
python3 -m pip install git+https://github.com/lhotse-speech/lhotse
# We are in ./icefall and there is a file: requirements.txt in it
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 9110e7db4..150b5258a 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -21,18 +21,19 @@ on:
branches:
- master
pull_request:
- branches:
- - master
+ types: [labeled]
jobs:
test:
+ if: github.event.label.name == 'ready' || github.event_name == 'push'
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04, macos-10.15]
python-version: [3.6, 3.7, 3.8, 3.9]
torch: ["1.8.1"]
- k2-version: ["1.4.dev20210822"]
+ k2-version: ["1.9.dev20210919"]
+
fail-fast: false
steps:
@@ -52,6 +53,20 @@ jobs:
# icefall requirements
pip install -r requirements.txt
+ - name: Install graphviz
+ if: startsWith(matrix.os, 'ubuntu')
+ shell: bash
+ run: |
+ python3 -m pip install -qq graphviz
+ sudo apt-get -qq install graphviz
+
+ - name: Install graphviz
+ if: startsWith(matrix.os, 'macos')
+ shell: bash
+ run: |
+ python3 -m pip install -qq graphviz
+ brew install -q graphviz
+
- name: Run tests
if: startsWith(matrix.os, 'ubuntu')
run: |
diff --git a/.gitignore b/.gitignore
index 839a1c34a..e6c84ca5e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,4 +4,4 @@ path.sh
exp
exp*/
*.pt
-download/
+download
diff --git a/docs/source/conf.py b/docs/source/conf.py
index f97f72d54..599df8b3e 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -16,7 +16,6 @@
import sphinx_rtd_theme
-
# -- Project information -----------------------------------------------------
project = "icefall"
diff --git a/docs/source/contributing/how-to-create-a-recipe.rst b/docs/source/contributing/how-to-create-a-recipe.rst
index 2d53fd89f..a30fb9056 100644
--- a/docs/source/contributing/how-to-create-a-recipe.rst
+++ b/docs/source/contributing/how-to-create-a-recipe.rst
@@ -56,7 +56,7 @@ organize your files in the following way:
$ cd egs/foo/ASR
$ mkdir bar
$ cd bar
- $ tourch README.md model.py train.py decode.py asr_datamodule.py pretrained.py
+ $ touch README.md model.py train.py decode.py asr_datamodule.py pretrained.py
For instance , the ``yesno`` recipe has a ``tdnn`` model and its directory structure
looks like the following:
diff --git a/docs/source/installation/images/device-CPU_CUDA-orange.svg b/docs/source/installation/images/device-CPU_CUDA-orange.svg
index b760102e3..a023a1283 100644
--- a/docs/source/installation/images/device-CPU_CUDA-orange.svg
+++ b/docs/source/installation/images/device-CPU_CUDA-orange.svg
@@ -1 +1 @@
-
\ No newline at end of file
+
diff --git a/docs/source/installation/images/k2-v1.9-blueviolet.svg b/docs/source/installation/images/k2-v1.9-blueviolet.svg
new file mode 100644
index 000000000..5a207b370
--- /dev/null
+++ b/docs/source/installation/images/k2-v1.9-blueviolet.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/docs/source/installation/images/os-Linux_macOS-ff69b4.svg b/docs/source/installation/images/os-Linux_macOS-ff69b4.svg
index 44c112747..178813ed4 100644
--- a/docs/source/installation/images/os-Linux_macOS-ff69b4.svg
+++ b/docs/source/installation/images/os-Linux_macOS-ff69b4.svg
@@ -1 +1 @@
-
\ No newline at end of file
+
diff --git a/docs/source/installation/images/python-3.6_3.7_3.8_3.9-blue.svg b/docs/source/installation/images/python-3.6_3.7_3.8_3.9-blue.svg
index 676feba2c..befc1e19e 100644
--- a/docs/source/installation/images/python-3.6_3.7_3.8_3.9-blue.svg
+++ b/docs/source/installation/images/python-3.6_3.7_3.8_3.9-blue.svg
@@ -1 +1 @@
-
\ No newline at end of file
+
diff --git a/docs/source/installation/images/torch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg b/docs/source/installation/images/torch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg
index d9b0efe1a..496e5a9ef 100644
--- a/docs/source/installation/images/torch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg
+++ b/docs/source/installation/images/torch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg
@@ -1 +1 @@
-
\ No newline at end of file
+
diff --git a/docs/source/installation/index.rst b/docs/source/installation/index.rst
index bcef669c8..f960033e8 100644
--- a/docs/source/installation/index.rst
+++ b/docs/source/installation/index.rst
@@ -7,6 +7,7 @@ Installation
- |device|
- |python_versions|
- |torch_versions|
+- |k2_versions|
.. |os| image:: ./images/os-Linux_macOS-ff69b4.svg
:alt: Supported operating systems
@@ -20,7 +21,10 @@ Installation
.. |torch_versions| image:: ./images/torch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg
:alt: Supported PyTorch versions
-icefall depends on `k2 `_ and
+.. |k2_versions| image:: ./images/k2-v1.9-blueviolet.svg
+ :alt: Supported k2 versions
+
+``icefall`` depends on `k2 `_ and
`lhotse `_.
We recommend you to install ``k2`` first, as ``k2`` is bound to
@@ -31,13 +35,17 @@ installs its dependency PyTorch, which can be reused by ``lhotse``.
(1) Install k2
--------------
-Please refer to ``_
-to install `k2`.
+Please refer to ``_
+to install ``k2``.
+
+.. CAUTION::
+
+ You need to install ``k2`` with a version at least **v1.9**.
.. HINT::
If you have already installed PyTorch and don't want to replace it,
- please install a version of k2 that is compiled against the version
+ please install a version of ``k2`` that is compiled against the version
of PyTorch you are using.
(2) Install lhotse
@@ -50,10 +58,15 @@ to install ``lhotse``.
Install ``lhotse`` also installs its dependency `torchaudio `_.
+.. CAUTION::
+
+ If you have installed ``torchaudio``, please consider uninstalling it before
+ installing ``lhotse``. Otherwise, it may update your already installed PyTorch.
+
(3) Download icefall
--------------------
-icefall is a collection of Python scripts, so you don't need to install it
+``icefall`` is a collection of Python scripts, so you don't need to install it
and we don't provide a ``setup.py`` to install it.
What you need is to download it and set the environment variable ``PYTHONPATH``
@@ -367,7 +380,7 @@ Now let us run the training part:
.. CAUTION::
- We use ``export CUDA_VISIBLE_DEVICES=""`` so that icefall uses CPU
+ We use ``export CUDA_VISIBLE_DEVICES=""`` so that ``icefall`` uses CPU
even if there are GPUs available.
The training log is given below:
diff --git a/docs/source/recipes/index.rst b/docs/source/recipes/index.rst
index db34fdca5..36f8dfc39 100644
--- a/docs/source/recipes/index.rst
+++ b/docs/source/recipes/index.rst
@@ -15,4 +15,3 @@ We may add recipes for other tasks as well in the future.
yesno
librispeech
-
diff --git a/docs/source/recipes/librispeech/conformer_ctc.rst b/docs/source/recipes/librispeech/conformer_ctc.rst
index 50f262a54..40100bc5a 100644
--- a/docs/source/recipes/librispeech/conformer_ctc.rst
+++ b/docs/source/recipes/librispeech/conformer_ctc.rst
@@ -45,7 +45,7 @@ For example,
.. code-block:: bash
- $ cd egs/yesno/ASR
+ $ cd egs/librispeech/ASR
$ ./prepare.sh --stage 0 --stop-stage 0
means to run only stage 0.
@@ -171,7 +171,7 @@ The following options are used quite often:
Pre-configured options
~~~~~~~~~~~~~~~~~~~~~~
-There are some training options, e.g., learning rate,
+There are some training options, e.g., weight decay,
number of warmup steps, results dir, etc,
that are not passed from the commandline.
They are pre-configured by the function ``get_params()`` in
@@ -303,7 +303,7 @@ The commonly used options are:
- ``--lattice-score-scale``
- It is used to scaled down lattice scores so that we can more unique
+ It is used to scale down lattice scores so that there are more unique
paths for rescoring.
- ``--max-duration``
@@ -314,7 +314,7 @@ The commonly used options are:
Pre-trained Model
-----------------
-We have uploaded the pre-trained model to
+We have uploaded a pre-trained model to
``_.
We describe how to use the pre-trained model to transcribe a sound file or
@@ -324,7 +324,7 @@ Install kaldifeat
~~~~~~~~~~~~~~~~~
`kaldifeat `_ is used to
-extract features for a single sound file or multiple soundfiles
+extract features for a single sound file or multiple sound files
at the same time.
Please refer to ``_ for installation.
@@ -346,6 +346,10 @@ The following commands describe how to download the pre-trained model:
You have to use ``git lfs`` to download the pre-trained model.
+.. CAUTION::
+
+ In order to use this pre-trained model, your k2 version has to be v1.7 or later.
+
After downloading, you will have the following files:
.. code-block:: bash
@@ -397,7 +401,7 @@ After downloading, you will have the following files:
- ``data/lm/G_4_gram.pt``
- It is a 4-gram LM, useful for LM rescoring.
+ It is a 4-gram LM, used for n-gram LM rescoring.
- ``exp/pretrained.pt``
@@ -409,9 +413,9 @@ After downloading, you will have the following files:
It contains some test sound files from LibriSpeech ``test-clean`` dataset.
- - `test_waves/trans.txt`
+ - ``test_waves/trans.txt``
- It contains the reference transcripts for the sound files in `test_waves/`.
+ It contains the reference transcripts for the sound files in ``test_waves/``.
The information of the test sound files is listed below:
@@ -556,7 +560,7 @@ Its output is:
HLG decoding + LM rescoring + attention decoder rescoring
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-It uses an n-gram LM to rescore the decoding lattice, extracts
+It uses an n-gram LM to rescore the decoding lattice, extracts
n paths from the rescored lattice, recores the extracted paths with
an attention decoder. The path with the highest score is the decoding result.
diff --git a/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst b/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst
index a59f34db7..848026802 100644
--- a/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst
+++ b/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst
@@ -153,10 +153,6 @@ Some commonly used options are:
will save the averaged model to ``tdnn_lstm_ctc/exp/pretrained.pt``.
See :ref:`tdnn_lstm_ctc use a pre-trained model` for how to use it.
-.. HINT::
-
- There are several decoding methods provided in `tdnn_lstm_ctc/decode.py `_, you can change the decoding method by modifying ``method`` parameter in function ``get_params()``.
-
.. _tdnn_lstm_ctc use a pre-trained model:
@@ -168,6 +164,16 @@ We have uploaded the pre-trained model to
The following shows you how to use the pre-trained model.
+
+Install kaldifeat
+~~~~~~~~~~~~~~~~~
+
+`kaldifeat `_ is used to
+extract features for a single sound file or multiple sound files
+at the same time.
+
+Please refer to ``_ for installation.
+
Download the pre-trained model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -183,6 +189,10 @@ Download the pre-trained model
You have to use ``git lfs`` to download the pre-trained model.
+.. CAUTION::
+
+ In order to use this pre-trained model, your k2 version has to be v1.7 or later.
+
After downloading, you will have the following files:
.. code-block:: bash
@@ -209,16 +219,78 @@ After downloading, you will have the following files:
|-- 1221-135766-0001.flac
|-- 1221-135766-0002.flac
`-- trans.txt
-
+
6 directories, 10 files
+**File descriptions**:
-Download kaldifeat
-~~~~~~~~~~~~~~~~~~
+ - ``data/lang_phone/HLG.pt``
+
+ It is the decoding graph.
+
+ - ``data/lang_phone/tokens.txt``
+
+ It contains tokens and their IDs.
+
+ - ``data/lang_phone/words.txt``
+
+ It contains words and their IDs.
+
+ - ``data/lm/G_4_gram.pt``
+
+ It is a 4-gram LM, useful for LM rescoring.
+
+ - ``exp/pretrained.pt``
+
+ It contains pre-trained model parameters, obtained by averaging
+ checkpoints from ``epoch-14.pt`` to ``epoch-19.pt``.
+ Note: We have removed optimizer ``state_dict`` to reduce file size.
+
+ - ``test_waves/*.flac``
+
+ It contains some test sound files from LibriSpeech ``test-clean`` dataset.
+
+ - ``test_waves/trans.txt``
+
+ It contains the reference transcripts for the sound files in ``test_waves/``.
+
+The information of the test sound files is listed below:
+
+.. code-block:: bash
+
+ $ soxi tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/*.flac
+
+ Input File : 'tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac'
+ Channels : 1
+ Sample Rate : 16000
+ Precision : 16-bit
+ Duration : 00:00:06.62 = 106000 samples ~ 496.875 CDDA sectors
+ File Size : 116k
+ Bit Rate : 140k
+ Sample Encoding: 16-bit FLAC
+
+
+ Input File : 'tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac'
+ Channels : 1
+ Sample Rate : 16000
+ Precision : 16-bit
+ Duration : 00:00:16.71 = 267440 samples ~ 1253.62 CDDA sectors
+ File Size : 343k
+ Bit Rate : 164k
+ Sample Encoding: 16-bit FLAC
+
+
+ Input File : 'tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac'
+ Channels : 1
+ Sample Rate : 16000
+ Precision : 16-bit
+ Duration : 00:00:04.83 = 77200 samples ~ 361.875 CDDA sectors
+ File Size : 105k
+ Bit Rate : 174k
+ Sample Encoding: 16-bit FLAC
+
+ Total Duration of 3 files: 00:00:28.16
-`kaldifeat `_ is used for extracting
-features from a single or multiple sound files. Please refer to
-``_ to install ``kaldifeat`` first.
Inference with a pre-trained model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -256,14 +328,14 @@ The output is:
2021-08-24 16:57:28,098 INFO [pretrained.py:266]
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac:
AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
-
+
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac:
GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONORED BOSOM TO CONNECT HER PARENT FOREVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN
-
+
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac:
YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
-
-
+
+
2021-08-24 16:57:28,099 INFO [pretrained.py:268] Decoding Done
@@ -297,14 +369,14 @@ The decoding output is:
2021-08-24 16:39:54,010 INFO [pretrained.py:266]
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac:
AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
-
+
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac:
GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONORED BOSOM TO CONNECT HER PARENT FOREVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN
-
+
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac:
YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
-
-
+
+
2021-08-24 16:39:54,010 INFO [pretrained.py:268] Decoding Done
diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md
index dfc412672..d04e912bf 100644
--- a/egs/librispeech/ASR/RESULTS.md
+++ b/egs/librispeech/ASR/RESULTS.md
@@ -21,6 +21,32 @@ To get more unique paths, we scaled the lattice.scores with 0.5 (see https://git
|test-clean|1.3|1.2|
|test-other|1.2|1.1|
+You can use the following commands to reproduce our results:
+
+```bash
+git clone https://github.com/k2-fsa/icefall
+cd icefall
+
+# The results were obtained with commit ef233486; you may not need to switch to it
+# git checkout ef233486
+
+cd egs/librispeech/ASR
+./prepare.sh
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+python conformer_ctc/train.py --bucketing-sampler True \
+ --concatenate-cuts False \
+ --max-duration 200 \
+ --full-libri True \
+ --world-size 4
+
+python conformer_ctc/decode.py --lattice-score-scale 0.5 \
+ --epoch 34 \
+ --avg 20 \
+ --method attention-decoder \
+ --max-duration 20 \
+ --num-paths 100
+```
### LibriSpeech training results (Tdnn-Lstm)
#### 2021-08-24
@@ -43,4 +69,3 @@ We searched the lm_score_scale for best results, the scales that produced the WE
|--|--|
|test-clean|0.8|
|test-other|0.9|
-
diff --git a/egs/librispeech/ASR/conformer_ctc/conformer.py b/egs/librispeech/ASR/conformer_ctc/conformer.py
index 08287d686..b19b94db1 100644
--- a/egs/librispeech/ASR/conformer_ctc/conformer.py
+++ b/egs/librispeech/ASR/conformer_ctc/conformer.py
@@ -56,8 +56,6 @@ class Conformer(Transformer):
cnn_module_kernel: int = 31,
normalize_before: bool = True,
vgg_frontend: bool = False,
- is_espnet_structure: bool = False,
- mmi_loss: bool = True,
use_feat_batchnorm: bool = False,
) -> None:
super(Conformer, self).__init__(
@@ -72,7 +70,6 @@ class Conformer(Transformer):
dropout=dropout,
normalize_before=normalize_before,
vgg_frontend=vgg_frontend,
- mmi_loss=mmi_loss,
use_feat_batchnorm=use_feat_batchnorm,
)
@@ -85,12 +82,10 @@ class Conformer(Transformer):
dropout,
cnn_module_kernel,
normalize_before,
- is_espnet_structure,
)
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
self.normalize_before = normalize_before
- self.is_espnet_structure = is_espnet_structure
- if self.normalize_before and self.is_espnet_structure:
+ if self.normalize_before:
self.after_norm = nn.LayerNorm(d_model)
else:
# Note: TorchScript detects that self.after_norm could be used inside forward()
@@ -103,7 +98,7 @@ class Conformer(Transformer):
"""
Args:
x:
- The model input. Its shape is [N, T, C].
+ The model input. Its shape is (N, T, C).
supervisions:
Supervision in lhotse format.
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
@@ -125,7 +120,7 @@ class Conformer(Transformer):
mask = mask.to(x.device)
x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F)
- if self.normalize_before and self.is_espnet_structure:
+ if self.normalize_before:
x = self.after_norm(x)
return x, mask
@@ -159,11 +154,10 @@ class ConformerEncoderLayer(nn.Module):
dropout: float = 0.1,
cnn_module_kernel: int = 31,
normalize_before: bool = True,
- is_espnet_structure: bool = False,
) -> None:
super(ConformerEncoderLayer, self).__init__()
self.self_attn = RelPositionMultiheadAttention(
- d_model, nhead, dropout=0.0, is_espnet_structure=is_espnet_structure
+ d_model, nhead, dropout=0.0
)
self.feed_forward = nn.Sequential(
@@ -436,7 +430,6 @@ class RelPositionMultiheadAttention(nn.Module):
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
- is_espnet_structure: bool = False,
) -> None:
super(RelPositionMultiheadAttention, self).__init__()
self.embed_dim = embed_dim
@@ -459,8 +452,6 @@ class RelPositionMultiheadAttention(nn.Module):
self._reset_parameters()
- self.is_espnet_structure = is_espnet_structure
-
def _reset_parameters(self) -> None:
nn.init.xavier_uniform_(self.in_proj.weight)
nn.init.constant_(self.in_proj.bias, 0.0)
@@ -690,9 +681,6 @@ class RelPositionMultiheadAttention(nn.Module):
_b = _b[_start:]
v = nn.functional.linear(value, _w, _b)
- if not self.is_espnet_structure:
- q = q * scaling
-
if attn_mask is not None:
assert (
attn_mask.dtype == torch.float32
@@ -785,14 +773,9 @@ class RelPositionMultiheadAttention(nn.Module):
) # (batch, head, time1, 2*time1-1)
matrix_bd = self.rel_shift(matrix_bd)
- if not self.is_espnet_structure:
- attn_output_weights = (
- matrix_ac + matrix_bd
- ) # (batch, head, time1, time2)
- else:
- attn_output_weights = (
- matrix_ac + matrix_bd
- ) * scaling # (batch, head, time1, time2)
+ attn_output_weights = (
+ matrix_ac + matrix_bd
+ ) * scaling # (batch, head, time1, time2)
attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, -1
diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py
index ff6374d73..b5b41c82e 100755
--- a/egs/librispeech/ASR/conformer_ctc/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc/decode.py
@@ -45,6 +45,7 @@ from icefall.utils import (
get_texts,
setup_logger,
store_transcripts,
+ str2bool,
write_error_stats,
)
@@ -107,7 +108,7 @@ def get_parser():
parser.add_argument(
"--lattice-score-scale",
type=float,
- default=1.0,
+ default=0.5,
help="""The scale to be applied to `lattice.scores`.
It's needed if you use any kinds of n-best based rescoring.
Used only when "method" is one of the following values:
@@ -116,6 +117,17 @@ def get_parser():
""",
)
+ parser.add_argument(
+ "--export",
+ type=str2bool,
+ default=False,
+ help="""When enabled, the averaged model is saved to
+ conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved.
+ pretrained.pt contains a dict {"model": model.state_dict()},
+ which can be loaded by `icefall.checkpoint.load_checkpoint()`.
+ """,
+ )
+
return parser
@@ -125,15 +137,15 @@ def get_params() -> AttributeDict:
"exp_dir": Path("conformer_ctc/exp"),
"lang_dir": Path("data/lang_bpe"),
"lm_dir": Path("data/lm"),
+ # parameters for conformer
+ "subsampling_factor": 4,
+ "vgg_frontend": False,
+ "use_feat_batchnorm": True,
"feature_dim": 80,
"nhead": 8,
"attention_dim": 512,
- "subsampling_factor": 4,
"num_decoder_layers": 6,
- "vgg_frontend": False,
- "is_espnet_structure": True,
- "mmi_loss": False,
- "use_feat_batchnorm": True,
+ # parameters for decoding
"search_beam": 20,
"output_beam": 8,
"min_active_states": 30,
@@ -201,12 +213,12 @@ def decode_one_batch(
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
- # at entry, feature is [N, T, C]
+ # at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
- # nnet_output is [N, T, C]
+ # nnet_output is (N, T, C)
supervision_segments = torch.stack(
(
@@ -232,14 +244,19 @@ def decode_one_batch(
# Note: You can also pass rescored lattices to it.
# We choose the HLG decoded lattice for speed reasons
# as HLG decoding is faster and the oracle WER
- # is slightly worse than that of rescored lattices.
- return nbest_oracle(
+ # is only slightly worse than that of rescored lattices.
+ best_path = nbest_oracle(
lattice=lattice,
num_paths=params.num_paths,
ref_texts=supervisions["text"],
word_table=word_table,
- scale=params.lattice_score_scale,
+ lattice_score_scale=params.lattice_score_scale,
+ oov="",
)
+ hyps = get_texts(best_path)
+ hyps = [[word_table[i] for i in ids] for ids in hyps]
+ key = f"oracle_{params.num_paths}_lattice_score_scale_{params.lattice_score_scale}" # noqa
+ return {key: hyps}
if params.method in ["1best", "nbest"]:
if params.method == "1best":
@@ -252,7 +269,7 @@ def decode_one_batch(
lattice=lattice,
num_paths=params.num_paths,
use_double_scores=params.use_double_scores,
- scale=params.lattice_score_scale,
+ lattice_score_scale=params.lattice_score_scale,
)
key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}" # noqa
@@ -266,7 +283,8 @@ def decode_one_batch(
"attention-decoder",
]
- lm_scale_list = [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
+ lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
+ lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
if params.method == "nbest-rescoring":
@@ -275,17 +293,23 @@ def decode_one_batch(
G=G,
num_paths=params.num_paths,
lm_scale_list=lm_scale_list,
- scale=params.lattice_score_scale,
+ lattice_score_scale=params.lattice_score_scale,
)
elif params.method == "whole-lattice-rescoring":
best_path_dict = rescore_with_whole_lattice(
- lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list
+ lattice=lattice,
+ G_with_epsilon_loops=G,
+ lm_scale_list=lm_scale_list,
)
elif params.method == "attention-decoder":
# lattice uses a 3-gram Lm. We rescore it with a 4-gram LM.
rescored_lattice = rescore_with_whole_lattice(
- lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
+ lattice=lattice,
+ G_with_epsilon_loops=G,
+ lm_scale_list=None,
)
+ # TODO: pass `lattice` instead of `rescored_lattice` to
+ # `rescore_with_attention_decoder`
best_path_dict = rescore_with_attention_decoder(
lattice=rescored_lattice,
@@ -295,16 +319,20 @@ def decode_one_batch(
memory_key_padding_mask=memory_key_padding_mask,
sos_id=sos_id,
eos_id=eos_id,
- scale=params.lattice_score_scale,
+ lattice_score_scale=params.lattice_score_scale,
)
else:
assert False, f"Unsupported decoding method: {params.method}"
ans = dict()
- for lm_scale_str, best_path in best_path_dict.items():
- hyps = get_texts(best_path)
- hyps = [[word_table[i] for i in ids] for ids in hyps]
- ans[lm_scale_str] = hyps
+ if best_path_dict is not None:
+ for lm_scale_str, best_path in best_path_dict.items():
+ hyps = get_texts(best_path)
+ hyps = [[word_table[i] for i in ids] for ids in hyps]
+ ans[lm_scale_str] = hyps
+ else:
+ for lm_scale in lm_scale_list:
+ ans[lm_scale] = [[] for _ in range(lattice.shape[0])]
return ans
@@ -525,8 +553,6 @@ def main():
subsampling_factor=params.subsampling_factor,
num_decoder_layers=params.num_decoder_layers,
vgg_frontend=params.vgg_frontend,
- is_espnet_structure=params.is_espnet_structure,
- mmi_loss=params.mmi_loss,
use_feat_batchnorm=params.use_feat_batchnorm,
)
@@ -541,6 +567,13 @@ def main():
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames))
+ if params.export:
+ logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
+ torch.save(
+ {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
+ )
+ return
+
model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py
index 95029fadb..c924b87bb 100755
--- a/egs/librispeech/ASR/conformer_ctc/pretrained.py
+++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py
@@ -173,17 +173,17 @@ def get_parser():
def get_params() -> AttributeDict:
params = AttributeDict(
{
+ "sample_rate": 16000,
+ # parameters for conformer
+ "subsampling_factor": 4,
+ "vgg_frontend": False,
+ "use_feat_batchnorm": True,
"feature_dim": 80,
"nhead": 8,
"num_classes": 5000,
- "sample_rate": 16000,
"attention_dim": 512,
- "subsampling_factor": 4,
"num_decoder_layers": 6,
- "vgg_frontend": False,
- "is_espnet_structure": True,
- "mmi_loss": False,
- "use_feat_batchnorm": True,
+ # parameters for decoding
"search_beam": 20,
"output_beam": 8,
"min_active_states": 30,
@@ -241,8 +241,6 @@ def main():
subsampling_factor=params.subsampling_factor,
num_decoder_layers=params.num_decoder_layers,
vgg_frontend=params.vgg_frontend,
- is_espnet_structure=params.is_espnet_structure,
- mmi_loss=params.mmi_loss,
use_feat_batchnorm=params.use_feat_batchnorm,
)
@@ -338,7 +336,7 @@ def main():
memory_key_padding_mask=memory_key_padding_mask,
sos_id=params.sos_id,
eos_id=params.eos_id,
- scale=params.lattice_score_scale,
+ lattice_score_scale=params.lattice_score_scale,
ngram_lm_scale=params.ngram_lm_scale,
attention_scale=params.attention_decoder_scale,
)
diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py
index 720ed6c22..542fb0364 100644
--- a/egs/librispeech/ASR/conformer_ctc/subsampling.py
+++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py
@@ -22,8 +22,8 @@ import torch.nn as nn
class Conv2dSubsampling(nn.Module):
"""Convolutional 2D subsampling (to 1/4 length).
- Convert an input of shape [N, T, idim] to an output
- with shape [N, T', odim], where
+ Convert an input of shape (N, T, idim) to an output
+ with shape (N, T', odim), where
T' = ((T-1)//2 - 1)//2, which approximates T' == T//4
It is based on
@@ -34,10 +34,10 @@ class Conv2dSubsampling(nn.Module):
"""
Args:
idim:
- Input dim. The input shape is [N, T, idim].
+ Input dim. The input shape is (N, T, idim).
Caution: It requires: T >=7, idim >=7
odim:
- Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim]
+ Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim)
"""
assert idim >= 7
super().__init__()
@@ -58,18 +58,18 @@ class Conv2dSubsampling(nn.Module):
Args:
x:
- Its shape is [N, T, idim].
+ Its shape is (N, T, idim).
Returns:
- Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim]
+ Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
"""
- # On entry, x is [N, T, idim]
- x = x.unsqueeze(1) # [N, T, idim] -> [N, 1, T, idim] i.e., [N, C, H, W]
+ # On entry, x is (N, T, idim)
+ x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
x = self.conv(x)
- # Now x is of shape [N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2]
+ # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
- # Now x is of shape [N, ((T-1)//2 - 1))//2, odim]
+ # Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
return x
@@ -80,8 +80,8 @@ class VggSubsampling(nn.Module):
This paper is not 100% explicit so I am guessing to some extent,
and trying to compare with other VGG implementations.
- Convert an input of shape [N, T, idim] to an output
- with shape [N, T', odim], where
+ Convert an input of shape (N, T, idim) to an output
+ with shape (N, T', odim), where
T' = ((T-1)//2 - 1)//2, which approximates T' = T//4
"""
@@ -93,10 +93,10 @@ class VggSubsampling(nn.Module):
Args:
idim:
- Input dim. The input shape is [N, T, idim].
+ Input dim. The input shape is (N, T, idim).
Caution: It requires: T >=7, idim >=7
odim:
- Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim]
+ Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim)
"""
super().__init__()
@@ -149,10 +149,10 @@ class VggSubsampling(nn.Module):
Args:
x:
- Its shape is [N, T, idim].
+ Its shape is (N, T, idim).
Returns:
- Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim]
+ Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
"""
x = x.unsqueeze(1)
x = self.layers(x)
diff --git a/egs/librispeech/ASR/conformer_ctc/test_subsampling.py b/egs/librispeech/ASR/conformer_ctc/test_subsampling.py
index e3361d0c9..81fa234dd 100755
--- a/egs/librispeech/ASR/conformer_ctc/test_subsampling.py
+++ b/egs/librispeech/ASR/conformer_ctc/test_subsampling.py
@@ -16,9 +16,8 @@
# limitations under the License.
-from subsampling import Conv2dSubsampling
-from subsampling import VggSubsampling
import torch
+from subsampling import Conv2dSubsampling, VggSubsampling
def test_conv2d_subsampling():
diff --git a/egs/librispeech/ASR/conformer_ctc/test_transformer.py b/egs/librispeech/ASR/conformer_ctc/test_transformer.py
index b90215274..667057c51 100644
--- a/egs/librispeech/ASR/conformer_ctc/test_transformer.py
+++ b/egs/librispeech/ASR/conformer_ctc/test_transformer.py
@@ -17,17 +17,16 @@
import torch
+from torch.nn.utils.rnn import pad_sequence
from transformer import (
Transformer,
+ add_eos,
+ add_sos,
+ decoder_padding_mask,
encoder_padding_mask,
generate_square_subsequent_mask,
- decoder_padding_mask,
- add_sos,
- add_eos,
)
-from torch.nn.utils.rnn import pad_sequence
-
def test_encoder_padding_mask():
supervisions = {
diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py
index b0dbe72ad..80b2d924a 100755
--- a/egs/librispeech/ASR/conformer_ctc/train.py
+++ b/egs/librispeech/ASR/conformer_ctc/train.py
@@ -1,5 +1,6 @@
#!/usr/bin/env python3
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -111,15 +112,6 @@ def get_params() -> AttributeDict:
- lang_dir: It contains language related input files such as
"lexicon.txt"
- - lr: It specifies the initial learning rate
-
- - feature_dim: The model input dim. It has to match the one used
- in computing features.
-
- - weight_decay: The weight_decay for the optimizer.
-
- - subsampling_factor: The subsampling factor for the model.
-
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
@@ -138,23 +130,40 @@ def get_params() -> AttributeDict:
- log_interval: Print training loss if batch_idx % log_interval` is 0
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
- valid_interval: Run validation if batch_idx % valid_interval is 0
- - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+ - use_feat_batchnorm: Whether to do batch normalization for the
+ input features.
+
+ - attention_dim: Hidden dim for multi-head attention model.
+
+ - nhead: Number of heads of multi-head attention model.
+
+ - num_decoder_layers: Number of decoder layers of the transformer decoder.
- beam_size: It is used in k2.ctc_loss
- reduction: It is used in k2.ctc_loss
- use_double_scores: It is used in k2.ctc_loss
+
+ - weight_decay: The weight_decay for the optimizer.
+
+ - lr_factor: The lr_factor for Noam optimizer.
+
+ - warm_step: The warm_step for Noam optimizer.
"""
params = AttributeDict(
{
"exp_dir": Path("conformer_ctc/exp"),
"lang_dir": Path("data/lang_bpe"),
- "feature_dim": 80,
- "weight_decay": 1e-6,
- "subsampling_factor": 4,
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
@@ -163,17 +172,20 @@ def get_params() -> AttributeDict:
"log_interval": 10,
"reset_interval": 200,
"valid_interval": 3000,
- "beam_size": 10,
- "reduction": "sum",
- "use_double_scores": True,
- "accum_grad": 1,
- "att_rate": 0.7,
+ # parameters for conformer
+ "feature_dim": 80,
+ "subsampling_factor": 4,
+ "use_feat_batchnorm": True,
"attention_dim": 512,
"nhead": 8,
"num_decoder_layers": 6,
- "is_espnet_structure": True,
- "mmi_loss": False,
- "use_feat_batchnorm": True,
+ # parameters for loss
+ "beam_size": 10,
+ "reduction": "sum",
+ "use_double_scores": True,
+ "att_rate": 0.7,
+ # parameters for Noam
+ "weight_decay": 1e-6,
"lr_factor": 5.0,
"warm_step": 80000,
}
@@ -298,14 +310,14 @@ def compute_loss(
"""
device = graph_compiler.device
feature = batch["inputs"]
- # at entry, feature is [N, T, C]
+ # at entry, feature is (N, T, C)
assert feature.ndim == 3
feature = feature.to(device)
supervisions = batch["supervisions"]
with torch.set_grad_enabled(is_training):
nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
- # nnet_output is [N, T, C]
+ # nnet_output is (N, T, C)
# NOTE: We need `encode_supervisions` to sort sequences with
# different duration in decreasing order, required by
@@ -646,8 +658,6 @@ def run(rank, world_size, args):
subsampling_factor=params.subsampling_factor,
num_decoder_layers=params.num_decoder_layers,
vgg_frontend=False,
- is_espnet_structure=params.is_espnet_structure,
- mmi_loss=params.mmi_loss,
use_feat_batchnorm=params.use_feat_batchnorm,
)
diff --git a/egs/librispeech/ASR/conformer_ctc/transformer.py b/egs/librispeech/ASR/conformer_ctc/transformer.py
index 191d2d612..68a4ff65c 100644
--- a/egs/librispeech/ASR/conformer_ctc/transformer.py
+++ b/egs/librispeech/ASR/conformer_ctc/transformer.py
@@ -41,7 +41,6 @@ class Transformer(nn.Module):
dropout: float = 0.1,
normalize_before: bool = True,
vgg_frontend: bool = False,
- mmi_loss: bool = True,
use_feat_batchnorm: bool = False,
) -> None:
"""
@@ -70,7 +69,6 @@ class Transformer(nn.Module):
If True, use pre-layer norm; False to use post-layer norm.
vgg_frontend:
True to use vgg style frontend for subsampling.
- mmi_loss:
use_feat_batchnorm:
True to use batchnorm for the input layer.
"""
@@ -85,8 +83,8 @@ class Transformer(nn.Module):
if subsampling_factor != 4:
raise NotImplementedError("Support only 'subsampling_factor=4'.")
- # self.encoder_embed converts the input of shape [N, T, num_classes]
- # to the shape [N, T//subsampling_factor, d_model].
+ # self.encoder_embed converts the input of shape (N, T, num_classes)
+ # to the shape (N, T//subsampling_factor, d_model).
# That is, it does two things simultaneously:
# (1) subsampling: T -> T//subsampling_factor
# (2) embedding: num_classes -> d_model
@@ -122,14 +120,9 @@ class Transformer(nn.Module):
)
if num_decoder_layers > 0:
- if mmi_loss:
- self.decoder_num_class = (
- self.num_classes + 1
- ) # +1 for the sos/eos symbol
- else:
- self.decoder_num_class = (
- self.num_classes
- ) # bpe model already has sos/eos symbol
+ self.decoder_num_class = (
+ self.num_classes
+ ) # bpe model already has sos/eos symbol
self.decoder_embed = nn.Embedding(
num_embeddings=self.decoder_num_class, embedding_dim=d_model
@@ -169,7 +162,7 @@ class Transformer(nn.Module):
"""
Args:
x:
- The input tensor. Its shape is [N, T, C].
+ The input tensor. Its shape is (N, T, C).
supervision:
Supervision in lhotse format.
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
@@ -178,17 +171,17 @@ class Transformer(nn.Module):
Returns:
Return a tuple containing 3 tensors:
- - CTC output for ctc decoding. Its shape is [N, T, C]
- - Encoder output with shape [T, N, C]. It can be used as key and
+ - CTC output for ctc decoding. Its shape is (N, T, C)
+ - Encoder output with shape (T, N, C). It can be used as key and
value for the decoder.
- Encoder output padding mask. It can be used as
- memory_key_padding_mask for the decoder. Its shape is [N, T].
+ memory_key_padding_mask for the decoder. Its shape is (N, T).
It is None if `supervision` is None.
"""
if self.use_feat_batchnorm:
- x = x.permute(0, 2, 1) # [N, T, C] -> [N, C, T]
+ x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
x = self.feat_batchnorm(x)
- x = x.permute(0, 2, 1) # [N, C, T] -> [N, T, C]
+ x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
encoder_memory, memory_key_padding_mask = self.run_encoder(
x, supervision
)
@@ -202,7 +195,7 @@ class Transformer(nn.Module):
Args:
x:
- The model input. Its shape is [N, T, C].
+ The model input. Its shape is (N, T, C).
supervisions:
Supervision in lhotse format.
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
@@ -213,8 +206,8 @@ class Transformer(nn.Module):
padding mask for the decoder.
Returns:
Return a tuple with two tensors:
- - The encoder output, with shape [T, N, C]
- - encoder padding mask, with shape [N, T].
+ - The encoder output, with shape (T, N, C)
+ - encoder padding mask, with shape (N, T).
The mask is None if `supervisions` is None.
It is used as memory key padding mask in the decoder.
"""
@@ -232,11 +225,11 @@ class Transformer(nn.Module):
Args:
x:
The output tensor from the transformer encoder.
- Its shape is [T, N, C]
+ Its shape is (T, N, C)
Returns:
Return a tensor that can be used for CTC decoding.
- Its shape is [N, T, C]
+ Its shape is (N, T, C)
"""
x = self.encoder_output_layer(x)
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
@@ -254,7 +247,7 @@ class Transformer(nn.Module):
"""
Args:
memory:
- It's the output of the encoder with shape [T, N, C]
+ It's the output of the encoder with shape (T, N, C)
memory_key_padding_mask:
The padding mask from the encoder.
token_ids:
@@ -319,7 +312,7 @@ class Transformer(nn.Module):
"""
Args:
memory:
- It's the output of the encoder with shape [T, N, C]
+ It's the output of the encoder with shape (T, N, C)
memory_key_padding_mask:
The padding mask from the encoder.
token_ids:
@@ -661,13 +654,13 @@ class PositionalEncoding(nn.Module):
def extend_pe(self, x: torch.Tensor) -> None:
"""Extend the time t in the positional encoding if required.
- The shape of `self.pe` is [1, T1, d_model]. The shape of the input x
- is [N, T, d_model]. If T > T1, then we change the shape of self.pe
- to [N, T, d_model]. Otherwise, nothing is done.
+ The shape of `self.pe` is (1, T1, d_model). The shape of the input x
+ is (N, T, d_model). If T > T1, then we change the shape of self.pe
+ to (N, T, d_model). Otherwise, nothing is done.
Args:
x:
- It is a tensor of shape [N, T, C].
+ It is a tensor of shape (N, T, C).
Returns:
Return None.
"""
@@ -685,7 +678,7 @@ class PositionalEncoding(nn.Module):
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
- # Now pe is of shape [1, T, d_model], where T is x.size(1)
+ # Now pe is of shape (1, T, d_model), where T is x.size(1)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -694,10 +687,10 @@ class PositionalEncoding(nn.Module):
Args:
x:
- Its shape is [N, T, C]
+ Its shape is (N, T, C)
Returns:
- Return a tensor of shape [N, T, C]
+ Return a tensor of shape (N, T, C)
"""
self.extend_pe(x)
x = x * self.xscale + self.pe[:, : x.size(1), :]
diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py
index 19a1ddd23..098d5d6a3 100755
--- a/egs/librispeech/ASR/local/compile_hlg.py
+++ b/egs/librispeech/ASR/local/compile_hlg.py
@@ -102,14 +102,14 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
LG.labels[LG.labels >= first_token_disambig_id] = 0
- assert isinstance(LG.aux_labels, k2.RaggedInt)
- LG.aux_labels.values()[LG.aux_labels.values() >= first_word_disambig_id] = 0
+ assert isinstance(LG.aux_labels, k2.RaggedTensor)
+ LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
LG = k2.remove_epsilon(LG)
logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
LG = k2.connect(LG)
- LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0)
+ LG.aux_labels = LG.aux_labels.remove_values_eq(0)
logging.info("Arc sorting LG")
LG = k2.arc_sort(LG)
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/Pre-trained.md b/egs/librispeech/ASR/tdnn_lstm_ctc/Pre-trained.md
deleted file mode 100644
index 83e98b37c..000000000
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/Pre-trained.md
+++ /dev/null
@@ -1,270 +0,0 @@
-
-# How to use a pre-trained model to transcribe a sound file or multiple sound files
-
-(See the bottom of this document for the link to a colab notebook.)
-
-You need to prepare 4 files:
-
- - a model checkpoint file, e.g., epoch-20.pt
- - HLG.pt, the decoding graph
- - words.txt, the word symbol table
- - a sound file, whose sampling rate has to be 16 kHz.
- Supported formats are those supported by `torchaudio.load()`,
- e.g., wav and flac.
-
-Also, you need to install `kaldifeat`. Please refer to
- for installation.
-
-```bash
-./tdnn_lstm_ctc/pretrained.py --help
-```
-
-displays the help information.
-
-## HLG decoding
-
-Once you have the above files ready and have `kaldifeat` installed,
-you can run:
-
-```bash
-./tdnn_lstm_ctc/pretrained.py \
- --checkpoint /path/to/your/checkpoint.pt \
- --words-file /path/to/words.txt \
- --HLG /path/to/HLG.pt \
- /path/to/your/sound.wav
-```
-
-and you will see the transcribed result.
-
-If you want to transcribe multiple files at the same time, you can use:
-
-```bash
-./tdnn_lstm_ctc/pretrained.py \
- --checkpoint /path/to/your/checkpoint.pt \
- --words-file /path/to/words.txt \
- --HLG /path/to/HLG.pt \
- /path/to/your/sound1.wav \
- /path/to/your/sound2.wav \
- /path/to/your/sound3.wav
-```
-
-**Note**: This is the fastest decoding method.
-
-## HLG decoding + LM rescoring
-
-`./tdnn_lstm_ctc/pretrained.py` also supports `whole lattice LM rescoring`.
-
-To use whole lattice LM rescoring, you also need the following files:
-
- - G.pt, e.g., `data/lm/G_4_gram.pt` if you have run `./prepare.sh`
-
-The command to run decoding with LM rescoring is:
-
-```bash
-./tdnn_lstm_ctc/pretrained.py \
- --checkpoint /path/to/your/checkpoint.pt \
- --words-file /path/to/words.txt \
- --HLG /path/to/HLG.pt \
- --method whole-lattice-rescoring \
- --G data/lm/G_4_gram.pt \
- --ngram-lm-scale 0.8 \
- /path/to/your/sound1.wav \
- /path/to/your/sound2.wav \
- /path/to/your/sound3.wav
-```
-
-# Decoding with a pre-trained model in action
-
-We have uploaded a pre-trained model to
-
-The following shows the steps about the usage of the provided pre-trained model.
-
-### (1) Download the pre-trained model
-
-```bash
-sudo apt-get install git-lfs
-cd /path/to/icefall/egs/librispeech/ASR
-git lfs install
-mkdir tmp
-cd tmp
-git clone https://huggingface.co/pkufool/icefall_asr_librispeech_tdnn-lstm_ctc
-```
-
-**CAUTION**: You have to install `git-lfs` to download the pre-trained model.
-
-You will find the following files:
-
-```
-tmp/
-`-- icefall_asr_librispeech_tdnn-lstm_ctc
- |-- README.md
- |-- data
- | |-- lang_phone
- | | |-- HLG.pt
- | | |-- tokens.txt
- | | `-- words.txt
- | `-- lm
- | `-- G_4_gram.pt
- |-- exp
- | `-- pretrained.pt
- `-- test_wavs
- |-- 1089-134686-0001.flac
- |-- 1221-135766-0001.flac
- |-- 1221-135766-0002.flac
- `-- trans.txt
-
-6 directories, 10 files
-```
-
-**File descriptions**:
-
- - `data/lang_phone/HLG.pt`
-
- It is the decoding graph.
-
- - `data/lang_phone/tokens.txt`
-
- It contains tokens and their IDs.
-
- - `data/lang_phone/words.txt`
-
- It contains words and their IDs.
-
- - `data/lm/G_4_gram.pt`
-
- It is a 4-gram LM, useful for LM rescoring.
-
- - `exp/pretrained.pt`
-
- It contains pre-trained model parameters, obtained by averaging
- checkpoints from `epoch-14.pt` to `epoch-19.pt`.
- Note: We have removed optimizer `state_dict` to reduce file size.
-
- - `test_waves/*.flac`
-
- It contains some test sound files from LibriSpeech `test-clean` dataset.
-
- - `test_waves/trans.txt`
-
- It contains the reference transcripts for the sound files in `test_waves/`.
-
-The information of the test sound files is listed below:
-
-```
-$ soxi tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/*.flac
-
-Input File : 'tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac'
-Channels : 1
-Sample Rate : 16000
-Precision : 16-bit
-Duration : 00:00:06.62 = 106000 samples ~ 496.875 CDDA sectors
-File Size : 116k
-Bit Rate : 140k
-Sample Encoding: 16-bit FLAC
-
-
-Input File : 'tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac'
-Channels : 1
-Sample Rate : 16000
-Precision : 16-bit
-Duration : 00:00:16.71 = 267440 samples ~ 1253.62 CDDA sectors
-File Size : 343k
-Bit Rate : 164k
-Sample Encoding: 16-bit FLAC
-
-
-Input File : 'tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac'
-Channels : 1
-Sample Rate : 16000
-Precision : 16-bit
-Duration : 00:00:04.83 = 77200 samples ~ 361.875 CDDA sectors
-File Size : 105k
-Bit Rate : 174k
-Sample Encoding: 16-bit FLAC
-
-Total Duration of 3 files: 00:00:28.16
-```
-
-### (2) Use HLG decoding
-
-```bash
-cd /path/to/icefall/egs/librispeech/ASR
-
-./tdnn_lstm_ctc/pretrained.py \
- --checkpoint ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/exp/pretraind.pt \
- --words-file ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/words.txt \
- --HLG ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/HLG.pt \
- ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac \
- ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac \
- ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac
-```
-
-The output is given below:
-
-```
-2021-08-24 16:57:13,315 INFO [pretrained.py:168] device: cuda:0
-2021-08-24 16:57:13,315 INFO [pretrained.py:170] Creating model
-2021-08-24 16:57:18,331 INFO [pretrained.py:182] Loading HLG from ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/HLG.pt
-2021-08-24 16:57:27,581 INFO [pretrained.py:199] Constructing Fbank computer
-2021-08-24 16:57:27,584 INFO [pretrained.py:209] Reading sound files: ['./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac']
-2021-08-24 16:57:27,599 INFO [pretrained.py:215] Decoding started
-2021-08-24 16:57:27,791 INFO [pretrained.py:245] Use HLG decoding
-2021-08-24 16:57:28,098 INFO [pretrained.py:266]
-./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac:
-AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
-
-./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac:
-GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONORED BOSOM TO CONNECT HER PARENT FOREVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN
-
-./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac:
-YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
-
-
-2021-08-24 16:57:28,099 INFO [pretrained.py:268] Decoding Done
-```
-
-### (3) Use HLG decoding + LM rescoring
-
-```bash
-./tdnn_lstm_ctc/pretrained.py \
- --checkpoint ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/exp/pretraind.pt \
- --words-file ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/words.txt \
- --HLG ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/HLG.pt \
- --method whole-lattice-rescoring \
- --G ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lm/G_4_gram.pt \
- --ngram-lm-scale 0.8 \
- ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac \
- ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac \
- ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac
-```
-
-The output is:
-
-```
-2021-08-24 16:39:24,725 INFO [pretrained.py:168] device: cuda:0
-2021-08-24 16:39:24,725 INFO [pretrained.py:170] Creating model
-2021-08-24 16:39:29,403 INFO [pretrained.py:182] Loading HLG from ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/HLG.pt
-2021-08-24 16:39:40,631 INFO [pretrained.py:190] Loading G from ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lm/G_4_gram.pt
-2021-08-24 16:39:53,098 INFO [pretrained.py:199] Constructing Fbank computer
-2021-08-24 16:39:53,107 INFO [pretrained.py:209] Reading sound files: ['./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac']
-2021-08-24 16:39:53,121 INFO [pretrained.py:215] Decoding started
-2021-08-24 16:39:53,443 INFO [pretrained.py:250] Use HLG decoding + LM rescoring
-2021-08-24 16:39:54,010 INFO [pretrained.py:266]
-./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac:
-AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
-
-./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac:
-GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONORED BOSOM TO CONNECT HER PARENT FOREVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN
-
-./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac:
-YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
-
-
-2021-08-24 16:39:54,010 INFO [pretrained.py:268] Decoding Done
-```
-
-**NOTE**: We provide a colab notebook for demonstration.
-[](https://colab.research.google.com/drive/1kNmDXNMwREi0rZGAOIAOJo93REBuOTcd?usp=sharing)
-
-Due to limited memory provided by Colab, you have to upgrade to Colab Pro to run `HLG decoding + LM rescoring`.
-Otherwise, you can only run `HLG decoding` with Colab.
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
index 91c1d6a96..8290e71d1 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -82,14 +82,14 @@ class LibriSpeechAsrDataModule(DataModule):
group.add_argument(
"--max-duration",
type=int,
- default=500.0,
+ default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
- default=False,
+ default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
index afdebd12b..1e91b1008 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
@@ -67,6 +67,47 @@ def get_parser():
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
+ parser.add_argument(
+ "--method",
+ type=str,
+ default="whole-lattice-rescoring",
+ help="""Decoding method.
+ Supported values are:
+ - (1) 1best. Extract the best path from the decoding lattice as the
+ decoding result.
+ - (2) nbest. Extract n paths from the decoding lattice; the path
+ with the highest score is the decoding result.
+ - (3) nbest-rescoring. Extract n paths from the decoding lattice,
+ rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
+ the highest score is the decoding result.
+ - (4) whole-lattice-rescoring. Rescore the decoding lattice with an
+ n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
+ is the decoding result.
+ """,
+ )
+
+ parser.add_argument(
+ "--num-paths",
+ type=int,
+ default=100,
+ help="""Number of paths for n-best based decoding method.
+ Used only when "method" is one of the following values:
+ nbest, nbest-rescoring
+ """,
+ )
+
+ parser.add_argument(
+ "--lattice-score-scale",
+ type=float,
+ default=0.5,
+ help="""The scale to be applied to `lattice.scores`.
+ It's needed if you use any kinds of n-best based rescoring.
+ Used only when "method" is one of the following values:
+ nbest, nbest-rescoring
+ A smaller value results in more unique paths.
+ """,
+ )
+
parser.add_argument(
"--export",
type=str2bool,
@@ -93,14 +134,6 @@ def get_params() -> AttributeDict:
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
- # Possible values for method:
- # - 1best
- # - nbest
- # - nbest-rescoring
- # - whole-lattice-rescoring
- "method": "whole-lattice-rescoring",
- # num_paths is used when method is "nbest" and "nbest-rescoring"
- "num_paths": 30,
}
)
return params
@@ -157,12 +190,12 @@ def decode_one_batch(
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
- # at entry, feature is [N, T, C]
+ # at entry, feature is (N, T, C)
- feature = feature.permute(0, 2, 1) # now feature is [N, C, T]
+ feature = feature.permute(0, 2, 1) # now feature is (N, C, T)
nnet_output = model(feature)
- # nnet_output is [N, T, C]
+ # nnet_output is (N, T, C)
supervisions = batch["supervisions"]
@@ -196,6 +229,7 @@ def decode_one_batch(
lattice=lattice,
num_paths=params.num_paths,
use_double_scores=params.use_double_scores,
+ lattice_score_scale=params.lattice_score_scale,
)
key = f"no_rescore-{params.num_paths}"
hyps = get_texts(best_path)
@@ -204,7 +238,8 @@ def decode_one_batch(
assert params.method in ["nbest-rescoring", "whole-lattice-rescoring"]
- lm_scale_list = [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
+ lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
+ lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
if params.method == "nbest-rescoring":
@@ -213,10 +248,13 @@ def decode_one_batch(
G=G,
num_paths=params.num_paths,
lm_scale_list=lm_scale_list,
+ lattice_score_scale=params.lattice_score_scale,
)
else:
best_path_dict = rescore_with_whole_lattice(
- lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list
+ lattice=lattice,
+ G_with_epsilon_loops=G,
+ lm_scale_list=lm_scale_list,
)
ans = dict()
@@ -424,6 +462,7 @@ def main():
torch.save(
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
)
+ return
model.to(device)
model.eval()
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
old mode 100644
new mode 100755
index 4f82a989c..0a543d859
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
@@ -218,11 +218,11 @@ def main():
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
- features = features.permute(0, 2, 1) # now features is [N, C, T]
+ features = features.permute(0, 2, 1) # now features is (N, C, T)
with torch.no_grad():
nnet_output = model(features)
- # nnet_output is [N, T, C]
+ # nnet_output is (N, T, C)
batch_size = nnet_output.shape[0]
supervision_segments = torch.tensor(
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
index 4d45d197b..695ee5130 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
@@ -290,14 +290,14 @@ def compute_loss(
"""
device = graph_compiler.device
feature = batch["inputs"]
- # at entry, feature is [N, T, C]
- feature = feature.permute(0, 2, 1) # now feature is [N, C, T]
+ # at entry, feature is (N, T, C)
+ feature = feature.permute(0, 2, 1) # now feature is (N, C, T)
assert feature.ndim == 3
feature = feature.to(device)
with torch.set_grad_enabled(is_training):
nnet_output = model(feature)
- # nnet_output is [N, T, C]
+ # nnet_output is (N, T, C)
# NOTE: We need `encode_supervisions` to sort sequences with
# different duration in decreasing order, required by
diff --git a/egs/yesno/ASR/local/compile_hlg.py b/egs/yesno/ASR/local/compile_hlg.py
index f2fafd013..9b6a4c5ba 100755
--- a/egs/yesno/ASR/local/compile_hlg.py
+++ b/egs/yesno/ASR/local/compile_hlg.py
@@ -80,14 +80,14 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
LG.labels[LG.labels >= first_token_disambig_id] = 0
- assert isinstance(LG.aux_labels, k2.RaggedInt)
- LG.aux_labels.values()[LG.aux_labels.values() >= first_word_disambig_id] = 0
+ assert isinstance(LG.aux_labels, k2.RaggedTensor)
+ LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
LG = k2.remove_epsilon(LG)
logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
LG = k2.connect(LG)
- LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0)
+ LG.aux_labels = LG.aux_labels.remove_values_eq(0)
logging.info("Arc sorting LG")
LG = k2.arc_sort(LG)
diff --git a/egs/yesno/ASR/tdnn/decode.py b/egs/yesno/ASR/tdnn/decode.py
index aa7b07b98..325acf316 100755
--- a/egs/yesno/ASR/tdnn/decode.py
+++ b/egs/yesno/ASR/tdnn/decode.py
@@ -111,10 +111,10 @@ def decode_one_batch(
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
- # at entry, feature is [N, T, C]
+ # at entry, feature is (N, T, C)
nnet_output = model(feature)
- # nnet_output is [N, T, C]
+ # nnet_output is (N, T, C)
batch_size = nnet_output.shape[0]
supervision_segments = torch.tensor(
@@ -296,6 +296,7 @@ def main():
torch.save(
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
)
+ return
model.to(device)
model.eval()
diff --git a/egs/yesno/ASR/tdnn/train.py b/egs/yesno/ASR/tdnn/train.py
index 39c5ef3ef..0f5506d38 100755
--- a/egs/yesno/ASR/tdnn/train.py
+++ b/egs/yesno/ASR/tdnn/train.py
@@ -268,13 +268,13 @@ def compute_loss(
"""
device = graph_compiler.device
feature = batch["inputs"]
- # at entry, feature is [N, T, C]
+ # at entry, feature is (N, T, C)
assert feature.ndim == 3
feature = feature.to(device)
with torch.set_grad_enabled(is_training):
nnet_output = model(feature)
- # nnet_output is [N, T, C]
+ # nnet_output is (N, T, C)
# NOTE: We need `encode_supervisions` to sort sequences with
# different duration in decreasing order, required by
diff --git a/icefall/decode.py b/icefall/decode.py
index de3219401..e678e4622 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -15,42 +15,12 @@
# limitations under the License.
import logging
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Union
import k2
-import kaldialign
import torch
-import torch.nn as nn
-
-def _get_random_paths(
- lattice: k2.Fsa,
- num_paths: int,
- use_double_scores: bool = True,
- scale: float = 1.0,
-):
- """
- Args:
- lattice:
- The decoding lattice, returned by :func:`get_lattice`.
- num_paths:
- It specifies the size `n` in n-best. Note: Paths are selected randomly
- and those containing identical word sequences are remove dand only one
- of them is kept.
- use_double_scores:
- True to use double precision floating point in the computation.
- False to use single precision.
- scale:
- It's the scale applied to the lattice.scores. A smaller value
- yields more unique paths.
- Returns:
- Return a k2.RaggedInt with 3 axes [seq][path][arc_pos]
- """
- saved_scores = lattice.scores.clone()
- lattice.scores *= scale
- path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)
- lattice.scores = saved_scores
- return path
+from icefall.utils import get_texts
def _intersect_device(
@@ -65,7 +35,7 @@ def _intersect_device(
CUDA OOM error.
The arguments and return value of this function are the same as
- k2.intersect_device.
+ :func:`k2.intersect_device`.
"""
num_fsas = b_fsas.shape[0]
if num_fsas <= batch_size:
@@ -84,8 +54,8 @@ def _intersect_device(
for start, end in splits:
indexes = torch.arange(start, end).to(b_to_a_map)
- fsas = k2.index(b_fsas, indexes)
- b_to_a = k2.index(b_to_a_map, indexes)
+ fsas = k2.index_fsa(b_fsas, indexes)
+ b_to_a = k2.index_select(b_to_a_map, indexes)
path_lattice = k2.intersect_device(
a_fsas, fsas, b_to_a_map=b_to_a, sorted_match_a=sorted_match_a
)
@@ -106,10 +76,9 @@ def get_lattice(
) -> k2.Fsa:
"""Get the decoding lattice from a decoding graph and neural
network output.
-
Args:
nnet_output:
- It is the output of a neural model of shape `[N, T, C]`.
+ It is the output of a neural model of shape `(N, T, C)`.
HLG:
An Fsa, the decoding graph. See also `compile_HLG.py`.
supervision_segments:
@@ -139,10 +108,12 @@ def get_lattice(
subsampling_factor:
The subsampling factor of the model.
Returns:
- A lattice containing the decoding result.
+ An FsaVec containing the decoding result. It has axes [utt][state][arc].
"""
dense_fsa_vec = k2.DenseFsaVec(
- nnet_output, supervision_segments, allow_truncate=subsampling_factor - 1
+ nnet_output,
+ supervision_segments,
+ allow_truncate=subsampling_factor - 1,
)
lattice = k2.intersect_dense_pruned(
@@ -157,8 +128,304 @@ def get_lattice(
return lattice
+class Nbest(object):
+ """
+ An Nbest object contains two fields:
+
+ (1) fsa. It is an FsaVec containing a vector of **linear** FSAs.
+ Its axes are [path][state][arc]
+ (2) shape. Its type is :class:`k2.RaggedShape`.
+ Its axes are [utt][path]
+
+ The field `shape` has two axes [utt][path]. `shape.dim0` contains
+ the number of utterances, which is also the number of rows in the
+ supervision_segments. `shape.tot_size(1)` contains the number
+ of paths, which is also the number of FSAs in `fsa`.
+
+ Caution:
+ Don't be confused by the name `Nbest`. The best in the name `Nbest`
+ has nothing to do with `best scores`. The important part is
+ `N` in `Nbest`, not `best`.
+ """
+
+ def __init__(self, fsa: k2.Fsa, shape: k2.RaggedShape) -> None:
+ """
+ Args:
+ fsa:
+ An FsaVec with axes [path][state][arc]. It is expected to contain
+ a list of **linear** FSAs.
+ shape:
+ A ragged shape with two axes [utt][path].
+ """
+ assert len(fsa.shape) == 3, f"fsa.shape: {fsa.shape}"
+ assert shape.num_axes == 2, f"num_axes: {shape.num_axes}"
+
+ if fsa.shape[0] != shape.tot_size(1):
+ raise ValueError(
+ f"{fsa.shape[0]} vs {shape.tot_size(1)}\n"
+ "Number of FSAs in `fsa` does not match the given shape"
+ )
+
+ self.fsa = fsa
+ self.shape = shape
+
+ def __str__(self):
+ s = "Nbest("
+ s += f"Number of utterances:{self.shape.dim0}, "
+ s += f"Number of Paths:{self.fsa.shape[0]})"
+ return s
+
+ @staticmethod
+ def from_lattice(
+ lattice: k2.Fsa,
+ num_paths: int,
+ use_double_scores: bool = True,
+ lattice_score_scale: float = 0.5,
+ ) -> "Nbest":
+ """Construct an Nbest object by **sampling** `num_paths` from a lattice.
+
+ Each sampled path is a linear FSA.
+
+ We assume `lattice.labels` contains token IDs and `lattice.aux_labels`
+ contains word IDs.
+
+ Args:
+ lattice:
+ An FsaVec with axes [utt][state][arc].
+ num_paths:
+ Number of paths to **sample** from the lattice
+ using :func:`k2.random_paths`.
+ use_double_scores:
+ True to use double precision in :func:`k2.random_paths`.
+ False to use single precision.
+        lattice_score_scale:
+            Scale `lattice.scores` before passing it to :func:`k2.random_paths`.
+            A smaller value leads to more unique paths at the risk of not
+            sampling the path with the best score.
+ Returns:
+ Return an Nbest instance.
+ """
+ saved_scores = lattice.scores.clone()
+ lattice.scores *= lattice_score_scale
+ # path is a ragged tensor with dtype torch.int32.
+ # It has three axes [utt][path][arc_pos]
+ path = k2.random_paths(
+ lattice, num_paths=num_paths, use_double_scores=use_double_scores
+ )
+ lattice.scores = saved_scores
+
+ # word_seq is a k2.RaggedTensor sharing the same shape as `path`
+ # but it contains word IDs. Note that it also contains 0s and -1s.
+ # The last entry in each sublist is -1.
+        # Its axes are [utt][path][word_id]
+ if isinstance(lattice.aux_labels, torch.Tensor):
+ word_seq = k2.ragged.index(lattice.aux_labels, path)
+ else:
+ word_seq = lattice.aux_labels.index(path)
+ word_seq = word_seq.remove_axis(word_seq.num_axes - 2)
+
+        # Each utterance has `num_paths` paths but some of them transduce
+        # to the same word sequence, so we need to remove repeated word
+        # sequences within an utterance. After removing repeats, each utterance
+        # contains a different number of paths.
+ #
+ # `new2old` is a 1-D torch.Tensor mapping from the output path index
+ # to the input path index.
+ _, _, new2old = word_seq.unique(
+ need_num_repeats=False, need_new2old_indexes=True
+ )
+
+ # kept_path is a ragged tensor with dtype torch.int32.
+ # It has axes [utt][path][arc_pos]
+ kept_path, _ = path.index(new2old, axis=1, need_value_indexes=False)
+
+ # utt_to_path_shape has axes [utt][path]
+ utt_to_path_shape = kept_path.shape.get_layer(0)
+
+ # Remove the utterance axis.
+ # Now kept_path has only two axes [path][arc_pos]
+ kept_path = kept_path.remove_axis(0)
+
+ # labels is a ragged tensor with 2 axes [path][token_id]
+ # Note that it contains -1s.
+ labels = k2.ragged.index(lattice.labels.contiguous(), kept_path)
+
+ # Remove -1 from labels as we will use it to construct a linear FSA
+ labels = labels.remove_values_eq(-1)
+
+ if isinstance(lattice.aux_labels, k2.RaggedTensor):
+ # lattice.aux_labels is a ragged tensor with dtype torch.int32.
+ # It has 2 axes [arc][word], so aux_labels is also a ragged tensor
+ # with 2 axes [arc][word]
+ aux_labels, _ = lattice.aux_labels.index(
+ indexes=kept_path.values, axis=0, need_value_indexes=False
+ )
+ else:
+ assert isinstance(lattice.aux_labels, torch.Tensor)
+ aux_labels = k2.index_select(lattice.aux_labels, kept_path.values)
+ # aux_labels is a 1-D torch.Tensor. It also contains -1 and 0.
+
+ fsa = k2.linear_fsa(labels)
+ fsa.aux_labels = aux_labels
+ # Caution: fsa.scores are all 0s.
+ # `fsa` has only one extra attribute: aux_labels.
+ return Nbest(fsa=fsa, shape=utt_to_path_shape)
+
+ def intersect(self, lattice: k2.Fsa, use_double_scores=True) -> "Nbest":
+ """Intersect this Nbest object with a lattice, get 1-best
+ path from the resulting FsaVec, and return a new Nbest object.
+
+ The purpose of this function is to attach scores to an Nbest.
+
+ Args:
+ lattice:
+ An FsaVec with axes [utt][state][arc]. If it has `aux_labels`, then
+ we assume its `labels` are token IDs and `aux_labels` are word IDs.
+ If it has only `labels`, we assume its `labels` are word IDs.
+ use_double_scores:
+ True to use double precision when computing shortest path.
+ False to use single precision.
+ Returns:
+ Return a new Nbest. This new Nbest shares the same shape with `self`,
+ while its `fsa` is the 1-best path from intersecting `self.fsa` and
+ `lattice`. Also, its `fsa` has non-zero scores and inherits attributes
+          from `lattice`.
+ """
+ # Note: We view each linear FSA as a word sequence
+ # and we use the passed lattice to give each word sequence a score.
+ #
+ # We are not viewing each linear FSAs as a token sequence.
+ #
+ # So we use k2.invert() here.
+
+ # We use a word fsa to intersect with k2.invert(lattice)
+ word_fsa = k2.invert(self.fsa)
+
+ if hasattr(lattice, "aux_labels"):
+ # delete token IDs as it is not needed
+ del word_fsa.aux_labels
+
+ word_fsa.scores.zero_()
+ word_fsa_with_epsilon_loops = k2.remove_epsilon_and_add_self_loops(
+ word_fsa
+ )
+
+ path_to_utt_map = self.shape.row_ids(1)
+
+ if hasattr(lattice, "aux_labels"):
+ # lattice has token IDs as labels and word IDs as aux_labels.
+ # inv_lattice has word IDs as labels and token IDs as aux_labels
+ inv_lattice = k2.invert(lattice)
+ inv_lattice = k2.arc_sort(inv_lattice)
+ else:
+ inv_lattice = k2.arc_sort(lattice)
+
+ if inv_lattice.shape[0] == 1:
+ path_lattice = _intersect_device(
+ inv_lattice,
+ word_fsa_with_epsilon_loops,
+ b_to_a_map=torch.zeros_like(path_to_utt_map),
+ sorted_match_a=True,
+ )
+ else:
+ path_lattice = _intersect_device(
+ inv_lattice,
+ word_fsa_with_epsilon_loops,
+ b_to_a_map=path_to_utt_map,
+ sorted_match_a=True,
+ )
+
+ # path_lattice has word IDs as labels and token IDs as aux_labels
+ path_lattice = k2.top_sort(k2.connect(path_lattice))
+
+ one_best = k2.shortest_path(
+ path_lattice, use_double_scores=use_double_scores
+ )
+
+ one_best = k2.invert(one_best)
+ # Now one_best has token IDs as labels and word IDs as aux_labels
+
+ return Nbest(fsa=one_best, shape=self.shape)
+
+ def compute_am_scores(self) -> k2.RaggedTensor:
+ """Compute AM scores of each linear FSA (i.e., each path within
+ an utterance).
+
+ Hint:
+ `self.fsa.scores` contains two parts: acoustic scores (AM scores)
+ and n-gram language model scores (LM scores).
+
+ Caution:
+ We require that ``self.fsa`` has an attribute ``lm_scores``.
+
+ Returns:
+ Return a ragged tensor with 2 axes [utt][path_scores].
+ Its dtype is torch.float64.
+ """
+ saved_scores = self.fsa.scores
+
+ # The `scores` of every arc consists of `am_scores` and `lm_scores`
+ self.fsa.scores = self.fsa.scores - self.fsa.lm_scores
+
+ am_scores = self.fsa.get_tot_scores(
+ use_double_scores=True, log_semiring=False
+ )
+ self.fsa.scores = saved_scores
+
+ return k2.RaggedTensor(self.shape, am_scores)
+
+ def compute_lm_scores(self) -> k2.RaggedTensor:
+ """Compute LM scores of each linear FSA (i.e., each path within
+ an utterance).
+
+ Hint:
+ `self.fsa.scores` contains two parts: acoustic scores (AM scores)
+ and n-gram language model scores (LM scores).
+
+ Caution:
+ We require that ``self.fsa`` has an attribute ``lm_scores``.
+
+ Returns:
+ Return a ragged tensor with 2 axes [utt][path_scores].
+ Its dtype is torch.float64.
+ """
+ saved_scores = self.fsa.scores
+
+ # The `scores` of every arc consists of `am_scores` and `lm_scores`
+ self.fsa.scores = self.fsa.lm_scores
+
+ lm_scores = self.fsa.get_tot_scores(
+ use_double_scores=True, log_semiring=False
+ )
+ self.fsa.scores = saved_scores
+
+ return k2.RaggedTensor(self.shape, lm_scores)
+
+ def tot_scores(self) -> k2.RaggedTensor:
+ """Get total scores of FSAs in this Nbest.
+
+ Note:
+ Since FSAs in Nbest are just linear FSAs, log-semiring
+ and tropical semiring produce the same total scores.
+
+ Returns:
+ Return a ragged tensor with two axes [utt][path_scores].
+ Its dtype is torch.float64.
+ """
+ scores = self.fsa.get_tot_scores(
+ use_double_scores=True, log_semiring=False
+ )
+ return k2.RaggedTensor(self.shape, scores)
+
+ def build_levenshtein_graphs(self) -> k2.Fsa:
+ """Return an FsaVec with axes [utt][state][arc]."""
+ word_ids = get_texts(self.fsa, return_ragged=True)
+ return k2.levenshtein_graph(word_ids)
+
+
def one_best_decoding(
- lattice: k2.Fsa, use_double_scores: bool = True
+ lattice: k2.Fsa,
+ use_double_scores: bool = True,
) -> k2.Fsa:
"""Get the best path from a lattice.
@@ -179,200 +446,143 @@ def nbest_decoding(
lattice: k2.Fsa,
num_paths: int,
use_double_scores: bool = True,
- scale: float = 1.0,
+ lattice_score_scale: float = 1.0,
) -> k2.Fsa:
"""It implements something like CTC prefix beam search using n-best lists.
- The basic idea is to first extra n-best paths from the given lattice,
- build a word seqs from these paths, and compute the total scores
- of these sequences in the log-semiring. The one with the max score
+ The basic idea is to first extract `num_paths` paths from the given lattice,
+ build a word sequence from these paths, and compute the total scores
+ of the word sequence in the tropical semiring. The one with the max score
is used as the decoding output.
Caution:
Don't be confused by `best` in the name `n-best`. Paths are selected
- randomly, not by ranking their scores.
+ **randomly**, not by ranking their scores.
+
+ Hint:
+ This decoding method is for demonstration only and it does
+ not produce a lower WER than :func:`one_best_decoding`.
Args:
lattice:
- The decoding lattice, returned by :func:`get_lattice`.
+ The decoding lattice, e.g., can be the return value of
+ :func:`get_lattice`. It has 3 axes [utt][state][arc].
num_paths:
It specifies the size `n` in n-best. Note: Paths are selected randomly
- and those containing identical word sequences are remove dand only one
+ and those containing identical word sequences are removed and only one
of them is kept.
use_double_scores:
True to use double precision floating point in the computation.
False to use single precision.
- scale:
- It's the scale applied to the lattice.scores. A smaller value
- yields more unique paths.
+ lattice_score_scale:
+ It's the scale applied to the `lattice.scores`. A smaller value
+ leads to more unique paths at the risk of missing the correct path.
Returns:
- An FsaVec containing linear FSAs.
+      An FsaVec containing **linear** FSAs. Its axes are [utt][state][arc].
"""
- path = _get_random_paths(
+ nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
- scale=scale,
+ lattice_score_scale=lattice_score_scale,
)
+ # nbest.fsa.scores contains 0s
- # word_seq is a k2.RaggedInt sharing the same shape as `path`
- # but it contains word IDs. Note that it also contains 0s and -1s.
- # The last entry in each sublist is -1.
- word_seq = k2.index(lattice.aux_labels, path)
- # Note: the above operation supports also the case when
- # lattice.aux_labels is a ragged tensor. In that case,
- # `remove_axis=True` is used inside the pybind11 binding code,
- # so the resulting `word_seq` still has 3 axes, like `path`.
- # The 3 axes are [seq][path][word_id]
+ nbest = nbest.intersect(lattice)
+ # now nbest.fsa.scores gets assigned
- # Remove 0 (epsilon) and -1 from word_seq
- word_seq = k2.ragged.remove_values_leq(word_seq, 0)
+ # max_indexes contains the indexes for the path with the maximum score
+ # within an utterance.
+ max_indexes = nbest.tot_scores().argmax()
- # Remove sequences with identical word sequences.
- #
- # k2.ragged.unique_sequences will reorder paths within a seq.
- # `new2old` is a 1-D torch.Tensor mapping from the output path index
- # to the input path index.
- # new2old.numel() == unique_word_seqs.tot_size(1)
- unique_word_seq, _, new2old = k2.ragged.unique_sequences(
- word_seq, need_num_repeats=False, need_new2old_indexes=True
- )
- # Note: unique_word_seq still has the same axes as word_seq
-
- seq_to_path_shape = k2.ragged.get_layer(unique_word_seq.shape(), 0)
-
- # path_to_seq_map is a 1-D torch.Tensor.
- # path_to_seq_map[i] is the seq to which the i-th path belongs
- path_to_seq_map = seq_to_path_shape.row_ids(1)
-
- # Remove the seq axis.
- # Now unique_word_seq has only two axes [path][word]
- unique_word_seq = k2.ragged.remove_axis(unique_word_seq, 0)
-
- # word_fsa is an FsaVec with axes [path][state][arc]
- word_fsa = k2.linear_fsa(unique_word_seq)
-
- # add epsilon self loops since we will use
- # k2.intersect_device, which treats epsilon as a normal symbol
- word_fsa_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsa)
-
- # lattice has token IDs as labels and word IDs as aux_labels.
- # inv_lattice has word IDs as labels and token IDs as aux_labels
- inv_lattice = k2.invert(lattice)
- inv_lattice = k2.arc_sort(inv_lattice)
-
- path_lattice = _intersect_device(
- inv_lattice,
- word_fsa_with_epsilon_loops,
- b_to_a_map=path_to_seq_map,
- sorted_match_a=True,
- )
- # path_lat has word IDs as labels and token IDs as aux_labels
-
- path_lattice = k2.top_sort(k2.connect(path_lattice))
-
- tot_scores = path_lattice.get_tot_scores(
- use_double_scores=use_double_scores, log_semiring=False
- )
-
- # RaggedFloat currently supports float32 only.
- # If Ragged is wrapped, we can use k2.RaggedDouble here
- ragged_tot_scores = k2.RaggedFloat(
- seq_to_path_shape, tot_scores.to(torch.float32)
- )
-
- argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)
-
- # Since we invoked `k2.ragged.unique_sequences`, which reorders
- # the index from `path`, we use `new2old` here to convert argmax_indexes
- # to the indexes into `path`.
- #
- # Use k2.index here since argmax_indexes' dtype is torch.int32
- best_path_indexes = k2.index(new2old, argmax_indexes)
-
- path_2axes = k2.ragged.remove_axis(path, 0)
-
- # best_path is a k2.RaggedInt with 2 axes [path][arc_pos]
- best_path = k2.index(path_2axes, best_path_indexes)
-
- # labels is a k2.RaggedInt with 2 axes [path][token_id]
- # Note that it contains -1s.
- labels = k2.index(lattice.labels.contiguous(), best_path)
-
- labels = k2.ragged.remove_values_eq(labels, -1)
-
- # lattice.aux_labels is a k2.RaggedInt tensor with 2 axes, so
- # aux_labels is also a k2.RaggedInt with 2 axes
- aux_labels = k2.index(lattice.aux_labels, best_path.values())
-
- best_path_fsa = k2.linear_fsa(labels)
- best_path_fsa.aux_labels = aux_labels
- return best_path_fsa
+ best_path = k2.index_fsa(nbest.fsa, max_indexes)
+ return best_path
-def compute_am_and_lm_scores(
+def nbest_oracle(
lattice: k2.Fsa,
- word_fsa_with_epsilon_loops: k2.Fsa,
- path_to_seq_map: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
- """Compute AM scores of n-best lists (represented as word_fsas).
+ num_paths: int,
+ ref_texts: List[str],
+ word_table: k2.SymbolTable,
+ use_double_scores: bool = True,
+ lattice_score_scale: float = 0.5,
+    oov: str = "<UNK>",
+) -> Dict[str, List[List[int]]]:
+ """Select the best hypothesis given a lattice and a reference transcript.
+
+ The basic idea is to extract `num_paths` paths from the given lattice,
+ unique them, and select the one that has the minimum edit distance with
+ the corresponding reference transcript as the decoding output.
+
+ The decoding result returned from this function is the best result that
+ we can obtain using n-best decoding with all kinds of rescoring techniques.
+
+ This function is useful to tune the value of `lattice_score_scale`.
Args:
lattice:
- An FsaVec, e.g., the return value of :func:`get_lattice`
- It must have the attribute `lm_scores`.
- word_fsa_with_epsilon_loops:
- An FsaVec representing an n-best list. Note that it has been processed
- by `k2.add_epsilon_self_loops`.
- path_to_seq_map:
- A 1-D torch.Tensor with dtype torch.int32. path_to_seq_map[i] indicates
- which sequence the i-th Fsa in word_fsa_with_epsilon_loops belongs to.
- path_to_seq_map.numel() == word_fsas_with_epsilon_loops.arcs.dim0().
- Returns:
- Return a tuple containing two 1-D torch.Tensors: (am_scores, lm_scores).
- Each tensor's `numel()' equals to `word_fsas_with_epsilon_loops.shape[0]`
+ An FsaVec with axes [utt][state][arc].
+ Note: We assume its `aux_labels` contains word IDs.
+ num_paths:
+ The size of `n` in n-best.
+ ref_texts:
+        A list of reference transcripts. Each entry contains space(s)
+        separated words.
+ word_table:
+ It is the word symbol table.
+ use_double_scores:
+ True to use double precision for computation. False to use
+ single precision.
+ lattice_score_scale:
+ It's the scale applied to the lattice.scores. A smaller value
+ yields more unique paths.
+ oov:
+ The out of vocabulary word.
+ Return:
+ Return a dict. Its key contains the information about the parameters
+ when calling this function, while its value contains the decoding output.
+ `len(ans_dict) == len(ref_texts)`
"""
- assert len(lattice.shape) == 3
- assert hasattr(lattice, "lm_scores")
+ device = lattice.device
- # k2.compose() currently does not support b_to_a_map. To void
- # replicating `lats`, we use k2.intersect_device here.
- #
- # lattice has token IDs as `labels` and word IDs as aux_labels, so we
- # need to invert it here.
- inv_lattice = k2.invert(lattice)
-
- # Now the `labels` of inv_lattice are word IDs (a 1-D torch.Tensor)
- # and its `aux_labels` are token IDs ( a k2.RaggedInt with 2 axes)
-
- # Remove its `aux_labels` since it is not needed in the
- # following computation
- del inv_lattice.aux_labels
- inv_lattice = k2.arc_sort(inv_lattice)
-
- path_lattice = _intersect_device(
- inv_lattice,
- word_fsa_with_epsilon_loops,
- b_to_a_map=path_to_seq_map,
- sorted_match_a=True,
+ nbest = Nbest.from_lattice(
+ lattice=lattice,
+ num_paths=num_paths,
+ use_double_scores=use_double_scores,
+ lattice_score_scale=lattice_score_scale,
)
- path_lattice = k2.top_sort(k2.connect(path_lattice))
+ hyps = nbest.build_levenshtein_graphs()
- # The `scores` of every arc consists of `am_scores` and `lm_scores`
- path_lattice.scores = path_lattice.scores - path_lattice.lm_scores
+ oov_id = word_table[oov]
+ word_ids_list = []
+ for text in ref_texts:
+ word_ids = []
+ for word in text.split():
+ if word in word_table:
+ word_ids.append(word_table[word])
+ else:
+ word_ids.append(oov_id)
+ word_ids_list.append(word_ids)
- am_scores = path_lattice.get_tot_scores(
- use_double_scores=True, log_semiring=False
+ refs = k2.levenshtein_graph(word_ids_list, device=device)
+
+ levenshtein_alignment = k2.levenshtein_alignment(
+ refs=refs,
+ hyps=hyps,
+ hyp_to_ref_map=nbest.shape.row_ids(1),
+ sorted_match_ref=True,
)
- path_lattice.scores = path_lattice.lm_scores
-
- lm_scores = path_lattice.get_tot_scores(
- use_double_scores=True, log_semiring=False
+ tot_scores = levenshtein_alignment.get_tot_scores(
+ use_double_scores=False, log_semiring=False
)
+ ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
- return am_scores.to(torch.float32), lm_scores.to(torch.float32)
+ max_indexes = ragged_tot_scores.argmax()
+
+ best_path = k2.index_fsa(nbest.fsa, max_indexes)
+ return best_path
def rescore_with_n_best_list(
@@ -380,34 +590,32 @@ def rescore_with_n_best_list(
G: k2.Fsa,
num_paths: int,
lm_scale_list: List[float],
- scale: float = 1.0,
+ lattice_score_scale: float = 1.0,
+ use_double_scores: bool = True,
) -> Dict[str, k2.Fsa]:
- """Decode using n-best list with LM rescoring.
-
- `lattice` is a decoding lattice with 3 axes. This function first
- extracts `num_paths` paths from `lattice` for each sequence using
- `k2.random_paths`. The `am_scores` of these paths are computed.
- For each path, its `lm_scores` is computed using `G` (which is an LM).
- The final `tot_scores` is the sum of `am_scores` and `lm_scores`.
- The path with the largest `tot_scores` within a sequence is used
- as the decoding output.
+ """Rescore an n-best list with an n-gram LM.
+ The path with the maximum score is used as the decoding output.
Args:
lattice:
- An FsaVec. It can be the return value of :func:`get_lattice`.
+ An FsaVec with axes [utt][state][arc]. It must have the following
+ attributes: ``aux_labels`` and ``lm_scores``. Its labels are
+ token IDs and ``aux_labels`` word IDs.
G:
- An FsaVec representing the language model (LM). Note that it
- is an FsaVec, but it contains only one Fsa.
+ An FsaVec containing only a single FSA. It is an n-gram LM.
num_paths:
- It is the size `n` in `n-best` list.
+ Size of nbest list.
lm_scale_list:
- A list containing lm_scale values.
- scale:
- It's the scale applied to the lattice.scores. A smaller value
- yields more unique paths.
+ A list of float representing LM score scales.
+ lattice_score_scale:
+          Scale to be applied to ``lattice.scores`` when sampling paths
+ using ``k2.random_paths``.
+ use_double_scores:
+ True to use double precision during computation. False to use
+ single precision.
Returns:
A dict of FsaVec, whose key is an lm_scale and the value is the
- best decoding path for each sequence in the lattice.
+ best decoding path for each utterance in the lattice.
"""
device = lattice.device
@@ -419,112 +627,32 @@ def rescore_with_n_best_list(
assert G.device == device
assert hasattr(G, "aux_labels") is False
- path = _get_random_paths(
+ nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
- use_double_scores=True,
- scale=scale,
+ use_double_scores=use_double_scores,
+ lattice_score_scale=lattice_score_scale,
)
+ # nbest.fsa.scores are all 0s at this point
- # word_seq is a k2.RaggedInt sharing the same shape as `path`
- # but it contains word IDs. Note that it also contains 0s and -1s.
- # The last entry in each sublist is -1.
- word_seq = k2.index(lattice.aux_labels, path)
+ nbest = nbest.intersect(lattice)
+ # Now nbest.fsa has its scores set
+ assert hasattr(nbest.fsa, "lm_scores")
- # Remove epsilons and -1 from word_seq
- word_seq = k2.ragged.remove_values_leq(word_seq, 0)
+ am_scores = nbest.compute_am_scores()
- # Remove paths that has identical word sequences.
- #
- # unique_word_seq is still a k2.RaggedInt with 3 axes [seq][path][word]
- # except that there are no repeated paths with the same word_seq
- # within a sequence.
- #
- # num_repeats is also a k2.RaggedInt with 2 axes containing the
- # multiplicities of each path.
- # num_repeats.num_elements() == unique_word_seqs.tot_size(1)
- #
- # Since k2.ragged.unique_sequences will reorder paths within a seq,
- # `new2old` is a 1-D torch.Tensor mapping from the output path index
- # to the input path index.
- # new2old.numel() == unique_word_seqs.tot_size(1)
- unique_word_seq, num_repeats, new2old = k2.ragged.unique_sequences(
- word_seq, need_num_repeats=True, need_new2old_indexes=True
- )
-
- seq_to_path_shape = k2.ragged.get_layer(unique_word_seq.shape(), 0)
-
- # path_to_seq_map is a 1-D torch.Tensor.
- # path_to_seq_map[i] is the seq to which the i-th path
- # belongs.
- path_to_seq_map = seq_to_path_shape.row_ids(1)
-
- # Remove the seq axis.
- # Now unique_word_seq has only two axes [path][word]
- unique_word_seq = k2.ragged.remove_axis(unique_word_seq, 0)
-
- # word_fsa is an FsaVec with axes [path][state][arc]
- word_fsa = k2.linear_fsa(unique_word_seq)
-
- word_fsa_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsa)
-
- am_scores, _ = compute_am_and_lm_scores(
- lattice, word_fsa_with_epsilon_loops, path_to_seq_map
- )
-
- # Now compute lm_scores
- b_to_a_map = torch.zeros_like(path_to_seq_map)
- lm_path_lattice = _intersect_device(
- G,
- word_fsa_with_epsilon_loops,
- b_to_a_map=b_to_a_map,
- sorted_match_a=True,
- )
- lm_path_lattice = k2.top_sort(k2.connect(lm_path_lattice))
- lm_scores = lm_path_lattice.get_tot_scores(
- use_double_scores=True, log_semiring=False
- )
-
- path_2axes = k2.ragged.remove_axis(path, 0)
+ nbest = nbest.intersect(G)
+ # Now nbest contains only lm scores
+ lm_scores = nbest.tot_scores()
ans = dict()
for lm_scale in lm_scale_list:
- tot_scores = am_scores / lm_scale + lm_scores
-
- # Remember that we used `k2.ragged.unique_sequences` to remove repeated
- # paths to avoid redundant computation in `k2.intersect_device`.
- # Now we use `num_repeats` to correct the scores for each path.
- #
- # NOTE(fangjun): It is commented out as it leads to a worse WER
- # tot_scores = tot_scores * num_repeats.values()
-
- ragged_tot_scores = k2.RaggedFloat(
- seq_to_path_shape, tot_scores.to(torch.float32)
- )
- argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)
-
- # Use k2.index here since argmax_indexes' dtype is torch.int32
- best_path_indexes = k2.index(new2old, argmax_indexes)
-
- # best_path is a k2.RaggedInt with 2 axes [path][arc_pos]
- best_path = k2.index(path_2axes, best_path_indexes)
-
- # labels is a k2.RaggedInt with 2 axes [path][phone_id]
- # Note that it contains -1s.
- labels = k2.index(lattice.labels.contiguous(), best_path)
-
- labels = k2.ragged.remove_values_eq(labels, -1)
-
- # lattice.aux_labels is a k2.RaggedInt tensor with 2 axes, so
- # aux_labels is also a k2.RaggedInt with 2 axes
- aux_labels = k2.index(lattice.aux_labels, best_path.values())
-
- best_path_fsa = k2.linear_fsa(labels)
- best_path_fsa.aux_labels = aux_labels
-
+ tot_scores = am_scores.values / lm_scale + lm_scores.values
+ tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
+ max_indexes = tot_scores.argmax()
+ best_path = k2.index_fsa(nbest.fsa, max_indexes)
key = f"lm_scale_{lm_scale}"
- ans[key] = best_path_fsa
-
+ ans[key] = best_path
return ans
@@ -532,25 +660,40 @@ def rescore_with_whole_lattice(
lattice: k2.Fsa,
G_with_epsilon_loops: k2.Fsa,
lm_scale_list: Optional[List[float]] = None,
+ use_double_scores: bool = True,
) -> Union[k2.Fsa, Dict[str, k2.Fsa]]:
- """Use whole lattice to rescore.
+ """Intersect the lattice with an n-gram LM and use shortest path
+ to decode.
+
+ The input lattice is obtained by intersecting `HLG` with
+ a DenseFsaVec, where the `G` in `HLG` is in general a 3-gram LM.
+ The input `G_with_epsilon_loops` is usually a 4-gram LM. You can consider
+ this function as a second pass decoding. In the first pass decoding, we
+ use a small G, while we use a larger G in the second pass decoding.
Args:
lattice:
- An FsaVec It can be the return value of :func:`get_lattice`.
+      An FsaVec with axes [utt][state][arc]. Its `aux_labels` are word IDs.
+ It must have an attribute `lm_scores`.
G_with_epsilon_loops:
- An FsaVec representing the language model (LM). Note that it
- is an FsaVec, but it contains only one Fsa.
+ An FsaVec containing only a single FSA. It contains epsilon self-loops.
+ It is an acceptor and its labels are word IDs.
lm_scale_list:
- A list containing lm_scale values or None.
+ Optional. If none, return the intersection of `lattice` and
+ `G_with_epsilon_loops`.
+ If not None, it contains a list of values to scale LM scores.
+ For each scale, there is a corresponding decoding result contained in
+ the resulting dict.
+ use_double_scores:
+ True to use double precision in the computation.
+ False to use single precision.
Returns:
- If lm_scale_list is not None, return a dict of FsaVec, whose key
- is a lm_scale and the value represents the best decoding path for
- each sequence in the lattice.
- If lm_scale_list is not None, return a lattice that is rescored
- with the given LM.
+ If `lm_scale_list` is None, return a new lattice which is the intersection
+ result of `lattice` and `G_with_epsilon_loops`.
+ Otherwise, return a dict whose key is an entry in `lm_scale_list` and the
+ value is the decoding result (i.e., an FsaVec containing linear FSAs).
"""
- assert len(lattice.shape) == 3
+ # Nbest is not used in this function
assert hasattr(lattice, "lm_scores")
assert G_with_epsilon_loops.shape == (1, None, None)
@@ -558,19 +701,22 @@ def rescore_with_whole_lattice(
lattice.scores = lattice.scores - lattice.lm_scores
# We will use lm_scores from G, so remove lats.lm_scores here
del lattice.lm_scores
- assert hasattr(lattice, "lm_scores") is False
assert hasattr(G_with_epsilon_loops, "lm_scores")
# Now, lattice.scores contains only am_scores
# inv_lattice has word IDs as labels.
- # Its aux_labels are token IDs, which is a ragged tensor k2.RaggedInt
+ # Its `aux_labels` is token IDs
inv_lattice = k2.invert(lattice)
num_seqs = lattice.shape[0]
b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)
- while True:
+
+ max_loop_count = 10
+ loop_count = 0
+ while loop_count <= max_loop_count:
+ loop_count += 1
try:
rescoring_lattice = k2.intersect_device(
G_with_epsilon_loops,
@@ -586,12 +732,15 @@ def rescore_with_whole_lattice(
f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}"
)
- # NOTE(fangjun): The choice of the threshold 1e-7 is arbitrary here
- # to avoid OOM. We may need to fine tune it.
- inv_lattice = k2.prune_on_arc_post(inv_lattice, 1e-7, True)
+ # NOTE(fangjun): The choice of the threshold 1e-9 is arbitrary here
+ # to avoid OOM. You may need to fine tune it.
+ inv_lattice = k2.prune_on_arc_post(inv_lattice, 1e-9, True)
logging.info(
f"num_arcs after pruning: {inv_lattice.arcs.num_elements()}"
)
+ if loop_count > max_loop_count:
+ logging.info("Return None as the resulting lattice is too large")
+ return None
# lat has token IDs as labels
# and word IDs as aux_labels.
@@ -601,112 +750,37 @@ def rescore_with_whole_lattice(
return lat
ans = dict()
- #
- # The following implements
- # scores = (scores - lm_scores)/lm_scale + lm_scores
- # = scores/lm_scale + lm_scores*(1 - 1/lm_scale)
- #
saved_am_scores = lat.scores - lat.lm_scores
for lm_scale in lm_scale_list:
am_scores = saved_am_scores / lm_scale
lat.scores = am_scores + lat.lm_scores
- best_path = k2.shortest_path(lat, use_double_scores=True)
+ best_path = k2.shortest_path(lat, use_double_scores=use_double_scores)
key = f"lm_scale_{lm_scale}"
ans[key] = best_path
return ans
-def nbest_oracle(
- lattice: k2.Fsa,
- num_paths: int,
- ref_texts: List[str],
- word_table: k2.SymbolTable,
- scale: float = 1.0,
-) -> Dict[str, List[List[int]]]:
- """Select the best hypothesis given a lattice and a reference transcript.
-
- The basic idea is to extract n paths from the given lattice, unique them,
- and select the one that has the minimum edit distance with the corresponding
- reference transcript as the decoding output.
-
- The decoding result returned from this function is the best result that
- we can obtain using n-best decoding with all kinds of rescoring techniques.
-
- Args:
- lattice:
- An FsaVec. It can be the return value of :func:`get_lattice`.
- Note: We assume its aux_labels contain word IDs.
- num_paths:
- The size of `n` in n-best.
- ref_texts:
- A list of reference transcript. Each entry contains space(s)
- separated words
- word_table:
- It is the word symbol table.
- scale:
- It's the scale applied to the lattice.scores. A smaller value
- yields more unique paths.
- Return:
- Return a dict. Its key contains the information about the parameters
- when calling this function, while its value contains the decoding output.
- `len(ans_dict) == len(ref_texts)`
- """
- path = _get_random_paths(
- lattice=lattice,
- num_paths=num_paths,
- use_double_scores=True,
- scale=scale,
- )
-
- word_seq = k2.index(lattice.aux_labels, path)
- word_seq = k2.ragged.remove_values_leq(word_seq, 0)
- unique_word_seq, _, _ = k2.ragged.unique_sequences(
- word_seq, need_num_repeats=False, need_new2old_indexes=False
- )
- unique_word_ids = k2.ragged.to_list(unique_word_seq)
- assert len(unique_word_ids) == len(ref_texts)
- # unique_word_ids[i] contains all hypotheses of the i-th utterance
-
- results = []
- for hyps, ref in zip(unique_word_ids, ref_texts):
- # Note hyps is a list-of-list ints
- # Each sublist contains a hypothesis
- ref_words = ref.strip().split()
- # CAUTION: We don't convert ref_words to ref_words_ids
- # since there may exist OOV words in ref_words
- best_hyp_words = None
- min_error = float("inf")
- for hyp_words in hyps:
- hyp_words = [word_table[i] for i in hyp_words]
- this_error = kaldialign.edit_distance(ref_words, hyp_words)["total"]
- if this_error < min_error:
- min_error = this_error
- best_hyp_words = hyp_words
- results.append(best_hyp_words)
-
- return {f"nbest_{num_paths}_scale_{scale}_oracle": results}
-
-
def rescore_with_attention_decoder(
lattice: k2.Fsa,
num_paths: int,
- model: nn.Module,
+ model: torch.nn.Module,
memory: torch.Tensor,
memory_key_padding_mask: Optional[torch.Tensor],
sos_id: int,
eos_id: int,
- scale: float = 1.0,
+ lattice_score_scale: float = 1.0,
ngram_lm_scale: Optional[float] = None,
attention_scale: Optional[float] = None,
+ use_double_scores: bool = True,
) -> Dict[str, k2.Fsa]:
- """This function extracts n paths from the given lattice and uses
- an attention decoder to rescore them. The path with the highest
- score is used as the decoding output.
+ """This function extracts `num_paths` paths from the given lattice and uses
+ an attention decoder to rescore them. The path with the highest score is
+ the decoding output.
Args:
lattice:
- An FsaVec. It can be the return value of :func:`get_lattice`.
+ An FsaVec with axes [utt][state][arc].
num_paths:
Number of paths to extract from the given lattice for rescoring.
model:
@@ -715,16 +789,16 @@ def rescore_with_attention_decoder(
memory:
The encoder memory of the given model. It is the output of
the last torch.nn.TransformerEncoder layer in the given model.
- Its shape is `[T, N, C]`.
+ Its shape is `(T, N, C)`.
memory_key_padding_mask:
- The padding mask for memory with shape [N, T].
+ The padding mask for memory with shape `(N, T)`.
sos_id:
The token ID for SOS.
eos_id:
The token ID for EOS.
- scale:
- It's the scale applied to the lattice.scores. A smaller value
- yields more unique paths.
+ lattice_score_scale:
+ It's the scale applied to `lattice.scores`. A smaller value
+ leads to more unique paths at the risk of missing the correct path.
ngram_lm_scale:
Optional. It specifies the scale for n-gram LM scores.
attention_scale:
@@ -732,97 +806,47 @@ def rescore_with_attention_decoder(
Returns:
A dict of FsaVec, whose key contains a string
ngram_lm_scale_attention_scale and the value is the
- best decoding path for each sequence in the lattice.
+ best decoding path for each utterance in the lattice.
"""
- # First, extract `num_paths` paths for each sequence.
- # path is a k2.RaggedInt with axes [seq][path][arc_pos]
- path = _get_random_paths(
+ nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
- use_double_scores=True,
- scale=scale,
+ use_double_scores=use_double_scores,
+ lattice_score_scale=lattice_score_scale,
)
+ # nbest.fsa.scores are all 0s at this point
- # word_seq is a k2.RaggedInt sharing the same shape as `path`
- # but it contains word IDs. Note that it also contains 0s and -1s.
- # The last entry in each sublist is -1.
- word_seq = k2.index(lattice.aux_labels, path)
+ nbest = nbest.intersect(lattice)
+ # Now nbest.fsa has its scores set.
+ # Also, nbest.fsa inherits the attributes from `lattice`.
+ assert hasattr(nbest.fsa, "lm_scores")
- # Remove epsilons and -1 from word_seq
- word_seq = k2.ragged.remove_values_leq(word_seq, 0)
+ am_scores = nbest.compute_am_scores()
+ ngram_lm_scores = nbest.compute_lm_scores()
- # Remove paths that has identical word sequences.
- #
- # unique_word_seq is still a k2.RaggedInt with 3 axes [seq][path][word]
- # except that there are no repeated paths with the same word_seq
- # within a sequence.
- #
- # num_repeats is also a k2.RaggedInt with 2 axes containing the
- # multiplicities of each path.
- # num_repeats.num_elements() == unique_word_seqs.tot_size(1)
- #
- # Since k2.ragged.unique_sequences will reorder paths within a seq,
- # `new2old` is a 1-D torch.Tensor mapping from the output path index
- # to the input path index.
- # new2old.numel() == unique_word_seq.tot_size(1)
- unique_word_seq, num_repeats, new2old = k2.ragged.unique_sequences(
- word_seq, need_num_repeats=True, need_new2old_indexes=True
- )
+ # The `tokens` attribute is set inside `compile_hlg.py`
+ assert hasattr(nbest.fsa, "tokens")
+ assert isinstance(nbest.fsa.tokens, torch.Tensor)
- seq_to_path_shape = k2.ragged.get_layer(unique_word_seq.shape(), 0)
-
- # path_to_seq_map is a 1-D torch.Tensor.
- # path_to_seq_map[i] is the seq to which the i-th path
- # belongs.
- path_to_seq_map = seq_to_path_shape.row_ids(1)
-
- # Remove the seq axis.
- # Now unique_word_seq has only two axes [path][word]
- unique_word_seq = k2.ragged.remove_axis(unique_word_seq, 0)
-
- # word_fsa is an FsaVec with axes [path][state][arc]
- word_fsa = k2.linear_fsa(unique_word_seq)
-
- word_fsa_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsa)
-
- am_scores, ngram_lm_scores = compute_am_and_lm_scores(
- lattice, word_fsa_with_epsilon_loops, path_to_seq_map
- )
- # Now we use the attention decoder to compute another
- # score: attention_scores.
- #
- # To do that, we have to get the input and output for the attention
- # decoder.
-
- # CAUTION: The "tokens" attribute is set in the file
- # local/compile_hlg.py
- token_seq = k2.index(lattice.tokens, path)
-
- # Remove epsilons and -1 from token_seq
- token_seq = k2.ragged.remove_values_leq(token_seq, 0)
-
- # Remove the seq axis.
- token_seq = k2.ragged.remove_axis(token_seq, 0)
-
- token_seq, _ = k2.ragged.index(
- token_seq, indexes=new2old, axis=0, need_value_indexes=False
- )
-
- # Now word in unique_word_seq has its corresponding token IDs.
- token_ids = k2.ragged.to_list(token_seq)
-
- num_word_seqs = new2old.numel()
-
- path_to_seq_map_long = path_to_seq_map.to(torch.long)
- expanded_memory = memory.index_select(1, path_to_seq_map_long)
+ path_to_utt_map = nbest.shape.row_ids(1).to(torch.long)
+ # the shape of memory is (T, N, C), so we use axis=1 here
+ expanded_memory = memory.index_select(1, path_to_utt_map)
if memory_key_padding_mask is not None:
+ # The shape of memory_key_padding_mask is (N, T), so we
+ # use axis=0 here.
expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
- 0, path_to_seq_map_long
+ 0, path_to_utt_map
)
else:
expanded_memory_key_padding_mask = None
+ # remove axis corresponding to states.
+ tokens_shape = nbest.fsa.arcs.shape().remove_axis(1)
+ tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens)
+ tokens = tokens.remove_values_leq(0)
+ token_ids = tokens.tolist()
+
nll = model.decoder_nll(
memory=expanded_memory,
memory_key_padding_mask=expanded_memory_key_padding_mask,
@@ -831,55 +855,36 @@ def rescore_with_attention_decoder(
eos_id=eos_id,
)
assert nll.ndim == 2
- assert nll.shape[0] == num_word_seqs
+ assert nll.shape[0] == len(token_ids)
attention_scores = -nll.sum(dim=1)
- assert attention_scores.ndim == 1
- assert attention_scores.numel() == num_word_seqs
if ngram_lm_scale is None:
- ngram_lm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
+ ngram_lm_scale_list = [0.01, 0.05, 0.08]
+ ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
else:
ngram_lm_scale_list = [ngram_lm_scale]
if attention_scale is None:
- attention_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
+ attention_scale_list = [0.01, 0.05, 0.08]
+ attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
else:
attention_scale_list = [attention_scale]
- path_2axes = k2.ragged.remove_axis(path, 0)
-
ans = dict()
for n_scale in ngram_lm_scale_list:
for a_scale in attention_scale_list:
tot_scores = (
- am_scores
- + n_scale * ngram_lm_scores
+ am_scores.values
+ + n_scale * ngram_lm_scores.values
+ a_scale * attention_scores
)
- ragged_tot_scores = k2.RaggedFloat(seq_to_path_shape, tot_scores)
- argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)
-
- best_path_indexes = k2.index(new2old, argmax_indexes)
-
- # best_path is a k2.RaggedInt with 2 axes [path][arc_pos]
- best_path = k2.index(path_2axes, best_path_indexes)
-
- # labels is a k2.RaggedInt with 2 axes [path][token_id]
- # Note that it contains -1s.
- labels = k2.index(lattice.labels.contiguous(), best_path)
-
- labels = k2.ragged.remove_values_eq(labels, -1)
-
- # lattice.aux_labels is a k2.RaggedInt tensor with 2 axes, so
- # aux_labels is also a k2.RaggedInt with 2 axes
- aux_labels = k2.index(lattice.aux_labels, best_path.values())
-
- best_path_fsa = k2.linear_fsa(labels)
- best_path_fsa.aux_labels = aux_labels
+ ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
+ max_indexes = ragged_tot_scores.argmax()
+ best_path = k2.index_fsa(nbest.fsa, max_indexes)
key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}"
- ans[key] = best_path_fsa
+ ans[key] = best_path
return ans
diff --git a/icefall/graph_compiler.py b/icefall/graph_compiler.py
index 23ac247e8..b4c87d964 100644
--- a/icefall/graph_compiler.py
+++ b/icefall/graph_compiler.py
@@ -106,7 +106,7 @@ class CtcTrainingGraphCompiler(object):
word_ids_list = []
for text in texts:
word_ids = []
- for word in text.split(" "):
+ for word in text.split():
if word in self.word_table:
word_ids.append(self.word_table[word])
else:
diff --git a/icefall/lexicon.py b/icefall/lexicon.py
index f1127c7cf..6730bac49 100644
--- a/icefall/lexicon.py
+++ b/icefall/lexicon.py
@@ -157,7 +157,7 @@ class BpeLexicon(Lexicon):
lang_dir / "lexicon.txt"
)
- def convert_lexicon_to_ragged(self, filename: str) -> k2.RaggedInt:
+ def convert_lexicon_to_ragged(self, filename: str) -> k2.RaggedTensor:
"""Read a BPE lexicon from file and convert it to a
k2 ragged tensor.
@@ -200,19 +200,18 @@ class BpeLexicon(Lexicon):
)
values = torch.tensor(token_ids, dtype=torch.int32)
- return k2.RaggedInt(shape, values)
+ return k2.RaggedTensor(shape, values)
- def words_to_piece_ids(self, words: List[str]) -> k2.RaggedInt:
+ def words_to_piece_ids(self, words: List[str]) -> k2.RaggedTensor:
"""Convert a list of words to a ragged tensor contained
word piece IDs.
"""
word_ids = [self.word_table[w] for w in words]
word_ids = torch.tensor(word_ids, dtype=torch.int32)
- ragged, _ = k2.ragged.index(
- self.ragged_lexicon,
+ ragged, _ = self.ragged_lexicon.index(
indexes=word_ids,
- need_value_indexes=False,
axis=0,
+ need_value_indexes=False,
)
return ragged
diff --git a/icefall/utils.py b/icefall/utils.py
index 2994c2d47..23b4dd6c7 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -26,7 +26,6 @@ from pathlib import Path
from typing import Dict, Iterable, List, TextIO, Tuple, Union
import k2
-import k2.ragged as k2r
import kaldialign
import torch
import torch.distributed as dist
@@ -147,12 +146,20 @@ def get_env_info():
}
-# See
-# https://stackoverflow.com/questions/4984647/accessing-dict-keys-like-an-attribute # noqa
class AttributeDict(dict):
- __slots__ = ()
- __getattr__ = dict.__getitem__
- __setattr__ = dict.__setitem__
+ def __getattr__(self, key):
+ if key in self:
+ return self[key]
+ raise AttributeError(f"No such attribute '{key}'")
+
+ def __setattr__(self, key, value):
+ self[key] = value
+
+ def __delattr__(self, key):
+ if key in self:
+ del self[key]
+ return
+ raise AttributeError(f"No such attribute '{key}'")
def encode_supervisions(
@@ -187,7 +194,9 @@ def encode_supervisions(
return supervision_segments, texts
-def get_texts(best_paths: k2.Fsa) -> List[List[int]]:
+def get_texts(
+ best_paths: k2.Fsa, return_ragged: bool = False
+) -> Union[List[List[int]], k2.RaggedTensor]:
"""Extract the texts (as word IDs) from the best-path FSAs.
Args:
best_paths:
@@ -195,30 +204,35 @@ def get_texts(best_paths: k2.Fsa) -> List[List[int]]:
containing multiple FSAs, which is expected to be the result
of k2.shortest_path (otherwise the returned values won't
be meaningful).
+ return_ragged:
+ True to return a ragged tensor with two axes [utt][word_id].
+ False to return a list-of-list word IDs.
Returns:
Returns a list of lists of int, containing the label sequences we
decoded.
"""
- if isinstance(best_paths.aux_labels, k2.RaggedInt):
+ if isinstance(best_paths.aux_labels, k2.RaggedTensor):
# remove 0's and -1's.
- aux_labels = k2r.remove_values_leq(best_paths.aux_labels, 0)
- aux_shape = k2r.compose_ragged_shapes(
- best_paths.arcs.shape(), aux_labels.shape()
- )
+ aux_labels = best_paths.aux_labels.remove_values_leq(0)
+ # TODO: change arcs.shape() to arcs.shape
+ aux_shape = best_paths.arcs.shape().compose(aux_labels.shape)
# remove the states and arcs axes.
- aux_shape = k2r.remove_axis(aux_shape, 1)
- aux_shape = k2r.remove_axis(aux_shape, 1)
- aux_labels = k2.RaggedInt(aux_shape, aux_labels.values())
+ aux_shape = aux_shape.remove_axis(1)
+ aux_shape = aux_shape.remove_axis(1)
+ aux_labels = k2.RaggedTensor(aux_shape, aux_labels.values)
else:
# remove axis corresponding to states.
- aux_shape = k2r.remove_axis(best_paths.arcs.shape(), 1)
- aux_labels = k2.RaggedInt(aux_shape, best_paths.aux_labels)
+ aux_shape = best_paths.arcs.shape().remove_axis(1)
+ aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels)
# remove 0's and -1's.
- aux_labels = k2r.remove_values_leq(aux_labels, 0)
+ aux_labels = aux_labels.remove_values_leq(0)
- assert aux_labels.num_axes() == 2
- return k2r.to_list(aux_labels)
+ assert aux_labels.num_axes == 2
+ if return_ragged:
+ return aux_labels
+ else:
+ return aux_labels.tolist()
def store_transcripts(
diff --git a/test/test_bpe_graph_compiler.py b/test/test_bpe_graph_compiler.py
index 67d300b7d..e58c4f1c6 100755
--- a/test/test_bpe_graph_compiler.py
+++ b/test/test_bpe_graph_compiler.py
@@ -16,9 +16,10 @@
# limitations under the License.
+from pathlib import Path
+
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.lexicon import BpeLexicon
-from pathlib import Path
def test():
diff --git a/test/test_decode.py b/test/test_decode.py
new file mode 100644
index 000000000..7ef127781
--- /dev/null
+++ b/test/test_decode.py
@@ -0,0 +1,62 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+You can run this file in one of the two ways:
+
+ (1) cd icefall; pytest test/test_decode.py
+ (2) cd icefall; ./test/test_decode.py
+"""
+
+import k2
+from icefall.decode import Nbest
+
+
+def test_nbest_from_lattice():
+ s = """
+ 0 1 1 10 0.1
+ 0 1 5 10 0.11
+ 0 1 2 20 0.2
+ 1 2 3 30 0.3
+ 1 2 4 40 0.4
+ 2 3 -1 -1 0.5
+ 3
+ """
+ lattice = k2.Fsa.from_str(s, acceptor=False)
+ lattice = k2.Fsa.from_fsas([lattice, lattice])
+
+ nbest = Nbest.from_lattice(
+ lattice=lattice,
+ num_paths=10,
+ use_double_scores=True,
+ lattice_score_scale=0.5,
+ )
+ # each lattice has only 4 distinct paths that have different word sequences:
+ # 10->30
+ # 10->40
+ # 20->30
+ # 20->40
+ #
+ # So there should be only 4 paths for each lattice in the Nbest object
+ assert nbest.fsa.shape[0] == 4 * 2
+ assert nbest.shape.row_splits(1).tolist() == [0, 4, 8]
+
+ nbest2 = nbest.intersect(lattice)
+ tot_scores = nbest2.tot_scores()
+ argmax = tot_scores.argmax()
+ best_path = k2.index_fsa(nbest2.fsa, argmax)
+ print(best_path[0])
diff --git a/test/test_utils.py b/test/test_utils.py
index 2dd79689f..7ac52b289 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -60,7 +60,7 @@ def test_get_texts_ragged():
4
"""
)
- fsa1.aux_labels = k2.RaggedInt("[ [1 3 0 2] [] [4 0 1] [-1]]")
+ fsa1.aux_labels = k2.RaggedTensor("[ [1 3 0 2] [] [4 0 1] [-1]]")
fsa2 = k2.Fsa.from_str(
"""
@@ -70,7 +70,7 @@ def test_get_texts_ragged():
3
"""
)
- fsa2.aux_labels = k2.RaggedInt("[[3 0 5 0 8] [0 9 7 0] [-1]]")
+ fsa2.aux_labels = k2.RaggedTensor("[[3 0 5 0 8] [0 9 7 0] [-1]]")
fsas = k2.Fsa.from_fsas([fsa1, fsa2])
texts = get_texts(fsas)
assert texts == [[1, 3, 2, 4, 1], [3, 5, 8, 9, 7]]
@@ -108,3 +108,14 @@ def test_attribute_dict():
assert s["b"] == 20
s.c = 100
assert s["c"] == 100
+ assert hasattr(s, "a")
+ assert hasattr(s, "b")
+ assert getattr(s, "a") == 10
+ del s.a
+ assert hasattr(s, "a") is False
+ setattr(s, "c", 100)
+ s.c = 100
+ try:
+ del s.a
+ except AttributeError as ex:
+ print(f"Caught exception: {ex}")