Merge branch 'master' into whisper_zh

This commit is contained in:
Yuekai Zhang 2024-03-07 13:53:35 +07:00 committed by GitHub
commit 19e21ba3ff
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
220 changed files with 26487 additions and 671 deletions

View File

@ -11,6 +11,7 @@ ARG _KALDIFEAT_VERSION="${KALDIFEAT_VERSION}+cpu.torch${TORCH_VERSION}"
RUN apt-get update -y && \
apt-get install -qq -y \
cmake \
ffmpeg \
git \
git-lfs \
@ -50,6 +51,7 @@ RUN pip install --no-cache-dir \
onnxruntime \
pytest \
sentencepiece>=0.1.96 \
pypinyin==0.50.0 \
six \
tensorboard \
typeguard

View File

@ -6,8 +6,8 @@ import json
def version_gt(a, b):
a_major, a_minor = a.split(".")[:2]
b_major, b_minor = b.split(".")[:2]
a_major, a_minor = list(map(int, a.split(".")))[:2]
b_major, b_minor = list(map(int, b.split(".")))[:2]
if a_major > b_major:
return True
@ -18,8 +18,8 @@ def version_gt(a, b):
def version_ge(a, b):
a_major, a_minor = a.split(".")[:2]
b_major, b_minor = b.split(".")[:2]
a_major, a_minor = list(map(int, a.split(".")))[:2]
b_major, b_minor = list(map(int, b.split(".")))[:2]
if a_major > b_major:
return True
@ -43,11 +43,12 @@ def get_torchaudio_version(torch_version):
def get_matrix():
k2_version = "1.24.4.dev20231220"
kaldifeat_version = "1.25.3.dev20231221"
version = "1.2"
python_version = ["3.8", "3.9", "3.10", "3.11"]
k2_version = "1.24.4.dev20240223"
kaldifeat_version = "1.25.4.dev20240223"
version = "20240223"
python_version = ["3.8", "3.9", "3.10", "3.11", "3.12"]
torch_version = ["1.13.0", "1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"]
torch_version += ["2.2.0", "2.2.1"]
matrix = []
for p in python_version:
@ -57,6 +58,10 @@ def get_matrix():
if version_gt(p, "3.10") and not version_gt(t, "2.0"):
continue
# only torch>=2.2.0 supports python 3.12
if version_gt(p, "3.11") and not version_gt(t, "2.1"):
continue
matrix.append(
{
"k2-version": k2_version,

View File

@ -16,7 +16,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest]
image: ["torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]
image: ["torch2.2.1-cuda12.1", "torch2.2.1-cuda11.8", "torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]
steps:
# refer to https://github.com/actions/checkout

View File

@ -14,13 +14,20 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest]
image: ["torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]
image: ["torch2.2.1-cuda12.1", "torch2.2.1-cuda11.8", "torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]
steps:
# refer to https://github.com/actions/checkout
- uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Free space
shell: bash
run: |
df -h
rm -rf /opt/hostedtoolcache
df -h
- name: Run the build process with Docker
uses: addnab/docker-run-action@v3
with:

View File

@ -49,7 +49,7 @@ jobs:
- name: Install Python dependencies
run: |
python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0
python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0 isort==5.10.1
# Click issue fixed in https://github.com/psf/black/pull/2966
- name: Run flake8
@ -67,3 +67,9 @@ jobs:
working-directory: ${{github.workspace}}
run: |
black --check --diff .
- name: Run isort
shell: bash
working-directory: ${{github.workspace}}
run: |
isort --check --diff .

View File

@ -59,4 +59,7 @@ jobs:
cd /icefall
git config --global --add safe.directory /icefall
python3 -m torch.utils.collect_env
python3 -m k2.version
.github/scripts/yesno/ASR/run.sh

View File

@ -26,7 +26,7 @@ repos:
# E121,E123,E126,E226,E24,E704,W503,W504
- repo: https://github.com/pycqa/isort
rev: 5.11.5
rev: 5.10.1
hooks:
- id: isort
args: ["--profile=black"]

View File

@ -74,6 +74,9 @@ The [LibriSpeech][librispeech] recipe supports the most comprehensive set of mod
- LSTM-based Predictor
- [Stateless Predictor](https://research.google/pubs/rnn-transducer-with-stateless-prediction-network/)
#### Whisper
- [OpenAi Whisper](https://arxiv.org/abs/2212.04356) (We support fine-tuning on AiShell-1.)
If you are willing to contribute to icefall, please refer to [contributing](https://icefall.readthedocs.io/en/latest/contributing/index.html) for more details.
We would like to highlight the performance of some of the recipes here.

View File

@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
ARG DEBIAN_FRONTEND=noninteractive
# python 3.7
ARG K2_VERSION="1.24.4.dev20230725+cuda11.3.torch1.12.1"
ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda11.3.torch1.12.1"
ARG K2_VERSION="1.24.4.dev20240223+cuda11.3.torch1.12.1"
ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.3.torch1.12.1"
ARG TORCHAUDIO_VERSION="0.12.1+cu113"
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"

View File

@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
ARG DEBIAN_FRONTEND=noninteractive
# python 3.9
ARG K2_VERSION="1.24.4.dev20231021+cuda11.6.torch1.13.0"
ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda11.6.torch1.13.0"
ARG K2_VERSION="1.24.4.dev20240223+cuda11.6.torch1.13.0"
ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.6.torch1.13.0"
ARG TORCHAUDIO_VERSION="0.13.0+cu116"
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"

View File

@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
ARG DEBIAN_FRONTEND=noninteractive
# python 3.7
ARG K2_VERSION="1.24.3.dev20230726+cuda10.2.torch1.9.0"
ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda10.2.torch1.9.0"
ARG K2_VERSION="1.24.4.dev20240223+cuda10.2.torch1.9.0"
ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda10.2.torch1.9.0"
ARG TORCHAUDIO_VERSION="0.9.0"
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"

View File

@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
ARG DEBIAN_FRONTEND=noninteractive
# python 3.10
ARG K2_VERSION="1.24.4.dev20231021+cuda11.7.torch2.0.0"
ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda11.7.torch2.0.0"
ARG K2_VERSION="1.24.4.dev20240223+cuda11.7.torch2.0.0"
ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.7.torch2.0.0"
ARG TORCHAUDIO_VERSION="2.0.0+cu117"
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"

View File

@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
ARG DEBIAN_FRONTEND=noninteractive
# python 3.10
ARG K2_VERSION="1.24.4.dev20231021+cuda11.8.torch2.1.0"
ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda11.8.torch2.1.0"
ARG K2_VERSION="1.24.4.dev20240223+cuda11.8.torch2.1.0"
ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.8.torch2.1.0"
ARG TORCHAUDIO_VERSION="2.1.0+cu118"
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"

View File

@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
ARG DEBIAN_FRONTEND=noninteractive
# python 3.10
ARG K2_VERSION="1.24.4.dev20231021+cuda12.1.torch2.1.0"
ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda12.1.torch2.1.0"
ARG K2_VERSION="1.24.4.dev20240223+cuda12.1.torch2.1.0"
ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda12.1.torch2.1.0"
ARG TORCHAUDIO_VERSION="2.1.0+cu121"
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"

View File

@ -0,0 +1,70 @@
FROM pytorch/pytorch:2.2.0-cuda11.8-cudnn8-devel
ENV LC_ALL C.UTF-8
ARG DEBIAN_FRONTEND=noninteractive
# python 3.10
ARG K2_VERSION="1.24.4.dev20240223+cuda11.8.torch2.2.0"
ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.8.torch2.2.0"
ARG TORCHAUDIO_VERSION="2.2.0+cu118"
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
LABEL k2_version=${K2_VERSION}
LABEL kaldifeat_version=${KALDIFEAT_VERSION}
LABEL github_repo="https://github.com/k2-fsa/icefall"
RUN apt-get update && \
apt-get install -y --no-install-recommends \
curl \
vim \
libssl-dev \
autoconf \
automake \
bzip2 \
ca-certificates \
ffmpeg \
g++ \
gfortran \
git \
libtool \
make \
patch \
sox \
subversion \
unzip \
valgrind \
wget \
zlib1g-dev \
&& rm -rf /var/lib/apt/lists/*
# Install dependencies
RUN pip install --no-cache-dir \
torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
git+https://github.com/lhotse-speech/lhotse \
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
kaldi_native_io \
kaldialign \
kaldifst \
kaldilm \
sentencepiece>=0.1.96 \
tensorboard \
typeguard \
dill \
onnx \
onnxruntime \
onnxmltools \
multi_quantization \
typeguard \
numpy \
pytest \
graphviz
RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
cd /workspace/icefall && \
pip install --no-cache-dir -r requirements.txt
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
WORKDIR /workspace/icefall

View File

@ -0,0 +1,70 @@
FROM pytorch/pytorch:2.2.0-cuda12.1-cudnn8-devel
ENV LC_ALL C.UTF-8
ARG DEBIAN_FRONTEND=noninteractive
# python 3.10
ARG K2_VERSION="1.24.4.dev20240223+cuda12.1.torch2.2.0"
ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda12.1.torch2.2.0"
ARG TORCHAUDIO_VERSION="2.2.0+cu121"
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
LABEL k2_version=${K2_VERSION}
LABEL kaldifeat_version=${KALDIFEAT_VERSION}
LABEL github_repo="https://github.com/k2-fsa/icefall"
RUN apt-get update && \
apt-get install -y --no-install-recommends \
curl \
vim \
libssl-dev \
autoconf \
automake \
bzip2 \
ca-certificates \
ffmpeg \
g++ \
gfortran \
git \
libtool \
make \
patch \
sox \
subversion \
unzip \
valgrind \
wget \
zlib1g-dev \
&& rm -rf /var/lib/apt/lists/*
# Install dependencies
RUN pip install --no-cache-dir \
torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
git+https://github.com/lhotse-speech/lhotse \
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
kaldi_native_io \
kaldialign \
kaldifst \
kaldilm \
sentencepiece>=0.1.96 \
tensorboard \
typeguard \
dill \
onnx \
onnxruntime \
onnxmltools \
multi_quantization \
typeguard \
numpy \
pytest \
graphviz
RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
cd /workspace/icefall && \
pip install --no-cache-dir -r requirements.txt
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
WORKDIR /workspace/icefall

View File

@ -0,0 +1,70 @@
FROM pytorch/pytorch:2.2.1-cuda11.8-cudnn8-devel
ENV LC_ALL C.UTF-8
ARG DEBIAN_FRONTEND=noninteractive
# python 3.10
ARG K2_VERSION="1.24.4.dev20240223+cuda11.8.torch2.2.1"
ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.8.torch2.2.1"
ARG TORCHAUDIO_VERSION="2.2.1+cu118"
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
LABEL k2_version=${K2_VERSION}
LABEL kaldifeat_version=${KALDIFEAT_VERSION}
LABEL github_repo="https://github.com/k2-fsa/icefall"
RUN apt-get update && \
apt-get install -y --no-install-recommends \
curl \
vim \
libssl-dev \
autoconf \
automake \
bzip2 \
ca-certificates \
ffmpeg \
g++ \
gfortran \
git \
libtool \
make \
patch \
sox \
subversion \
unzip \
valgrind \
wget \
zlib1g-dev \
&& rm -rf /var/lib/apt/lists/*
# Install dependencies
RUN pip install --no-cache-dir \
torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
git+https://github.com/lhotse-speech/lhotse \
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
kaldi_native_io \
kaldialign \
kaldifst \
kaldilm \
sentencepiece>=0.1.96 \
tensorboard \
typeguard \
dill \
onnx \
onnxruntime \
onnxmltools \
multi_quantization \
typeguard \
numpy \
pytest \
graphviz
RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
cd /workspace/icefall && \
pip install --no-cache-dir -r requirements.txt
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
WORKDIR /workspace/icefall

View File

@ -0,0 +1,70 @@
FROM pytorch/pytorch:2.2.1-cuda12.1-cudnn8-devel
ENV LC_ALL C.UTF-8
ARG DEBIAN_FRONTEND=noninteractive
# python 3.10
ARG K2_VERSION="1.24.4.dev20240223+cuda12.1.torch2.2.1"
ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda12.1.torch2.2.1"
ARG TORCHAUDIO_VERSION="2.2.1+cu121"
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
LABEL k2_version=${K2_VERSION}
LABEL kaldifeat_version=${KALDIFEAT_VERSION}
LABEL github_repo="https://github.com/k2-fsa/icefall"
RUN apt-get update && \
apt-get install -y --no-install-recommends \
curl \
vim \
libssl-dev \
autoconf \
automake \
bzip2 \
ca-certificates \
ffmpeg \
g++ \
gfortran \
git \
libtool \
make \
patch \
sox \
subversion \
unzip \
valgrind \
wget \
zlib1g-dev \
&& rm -rf /var/lib/apt/lists/*
# Install dependencies
RUN pip install --no-cache-dir \
torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
git+https://github.com/lhotse-speech/lhotse \
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
kaldi_native_io \
kaldialign \
kaldifst \
kaldilm \
sentencepiece>=0.1.96 \
tensorboard \
typeguard \
dill \
onnx \
onnxruntime \
onnxmltools \
multi_quantization \
typeguard \
numpy \
pytest \
graphviz
RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
cd /workspace/icefall && \
pip install --no-cache-dir -r requirements.txt
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
WORKDIR /workspace/icefall

View File

@ -30,7 +30,7 @@ of langugae model integration.
First, let's have a look at some background information. As the predecessor of LODR, Density Ratio (DR) is first proposed `here <https://arxiv.org/abs/2002.11268>`_
to address the language information mismatch between the training
corpus (source domain) and the testing corpus (target domain). Assuming that the source domain and the test domain
are acoustically similar, DR derives the following formular for decoding with Bayes' theorem:
are acoustically similar, DR derives the following formula for decoding with Bayes' theorem:
.. math::
@ -41,7 +41,7 @@ are acoustically similar, DR derives the following formular for decoding with Ba
where :math:`\lambda_1` and :math:`\lambda_2` are the weights of LM scores for target domain and source domain respectively.
Here, the source domain LM is trained on the training corpus. The only difference in the above formular compared to
Here, the source domain LM is trained on the training corpus. The only difference in the above formula compared to
shallow fusion is the subtraction of the source domain LM.
Some works treat the predictor and the joiner of the neural transducer as its internal LM. However, the LM is
@ -58,7 +58,7 @@ during decoding for transducer model:
In LODR, an additional bi-gram LM estimated on the source domain (e.g training corpus) is required. Compared to DR,
the only difference lies in the choice of source domain LM. According to the original `paper <https://arxiv.org/abs/2203.16776>`_,
LODR achieves similar performance compared DR in both intra-domain and cross-domain settings.
LODR achieves similar performance compared to DR in both intra-domain and cross-domain settings.
As a bi-gram is much faster to evaluate, LODR is usually much faster.
Now, we will show you how to use LODR in ``icefall``.

View File

@ -9,9 +9,9 @@ to improve the word-error-rate of a transducer model.
.. note::
This tutorial is based on the recipe
This tutorial is based on the recipe
`pruned_transducer_stateless7_streaming <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless7_streaming>`_,
which is a streaming transducer model trained on `LibriSpeech`_.
which is a streaming transducer model trained on `LibriSpeech`_.
However, you can easily apply shallow fusion to other recipes.
If you encounter any problems, please open an issue here `icefall <https://github.com/k2-fsa/icefall/issues>`_.
@ -69,11 +69,11 @@ Training a language model usually takes a long time, we can download a pre-train
.. code-block:: bash
$ # download the external LM
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
$ # create a symbolic link so that the checkpoint can be loaded
$ pushd icefall-librispeech-rnn-lm/exp
$ git lfs pull --include "pretrained.pt"
$ ln -s pretrained.pt epoch-99.pt
$ ln -s pretrained.pt epoch-99.pt
$ popd
.. note::
@ -85,7 +85,7 @@ Training a language model usually takes a long time, we can download a pre-train
To use shallow fusion for decoding, we can execute the following command:
.. code-block:: bash
$ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
$ lm_dir=./icefall-librispeech-rnn-lm/exp
$ lm_scale=0.29
@ -133,16 +133,16 @@ The decoding result obtained with the above command are shown below.
$ For test-other, WER of different settings are:
$ beam_size_4 7.08 best for test-other
The improvement of shallow fusion is very obvious! The relative WER reduction on test-other is around 10.5%.
The improvement of shallow fusion is very obvious! The relative WER reduction on test-other is around 10.5%.
A few parameters can be tuned to further boost the performance of shallow fusion:
- ``--lm-scale``
- ``--lm-scale``
Controls the scale of the LM. If too small, the external language model may not be fully utilized; if too large,
the LM score may dominant during decoding, leading to bad WER. A typical value of this is around 0.3.
Controls the scale of the LM. If too small, the external language model may not be fully utilized; if too large,
the LM score might be dominant during decoding, leading to bad WER. A typical value of this is around 0.3.
- ``--beam-size``
- ``--beam-size``
The number of active paths in the search beam. It controls the trade-off between decoding efficiency and accuracy.
Here, we also show how `--beam-size` effect the WER and decoding time:
@ -176,4 +176,4 @@ As we see, a larger beam size during shallow fusion improves the WER, but is als

View File

@ -34,6 +34,10 @@ which will give you something like below:
.. code-block:: bash
"torch2.2.1-cuda12.1"
"torch2.2.1-cuda11.8"
"torch2.2.0-cuda12.1"
"torch2.2.0-cuda11.8"
"torch2.1.0-cuda12.1"
"torch2.1.0-cuda11.8"
"torch2.0.0-cuda11.7"

View File

@ -0,0 +1,140 @@
Finetune from a supervised pre-trained Zipformer model
======================================================
This tutorial shows you how to fine-tune a supervised pre-trained **Zipformer**
transducer model on a new dataset.
.. HINT::
We assume you have read the page :ref:`install icefall` and have setup
the environment for ``icefall``.
.. HINT::
We recommend you to use a GPU or several GPUs to run this recipe
For illustration purpose, we fine-tune the Zipformer transducer model
pre-trained on `LibriSpeech`_ on the small subset of `GigaSpeech`_. You could use your
own data for fine-tuning if you create a manifest for your new dataset.
Data preparation
----------------
Please follow the instructions in the `GigaSpeech recipe <https://github.com/k2-fsa/icefall/tree/master/egs/gigaspeech/ASR>`_
to prepare the fine-tune data used in this tutorial. We only require the small subset in GigaSpeech for this tutorial.
Model preparation
-----------------
We are using the Zipformer model trained on full LibriSpeech (960 hours) as the intialization. The
checkpoint of the model can be downloaded via the following command:
.. code-block:: bash
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
$ cd icefall-asr-librispeech-zipformer-2023-05-15/exp
$ git lfs pull --include "pretrained.pt"
$ ln -s pretrained.pt epoch-99.pt
$ cd ../data/lang_bpe_500
$ git lfs pull --include bpe.model
$ cd ../../..
Before fine-tuning, let's test the model's WER on the new domain. The following command performs
decoding on the GigaSpeech test sets:
.. code-block:: bash
./zipformer/decode_gigaspeech.py \
--epoch 99 \
--avg 1 \
--exp-dir icefall-asr-librispeech-zipformer-2023-05-15/exp \
--use-averaged-model 0 \
--max-duration 1000 \
--decoding-method greedy_search
You should see the following numbers:
.. code-block::
For dev, WER of different settings are:
greedy_search 20.06 best for dev
For test, WER of different settings are:
greedy_search 19.27 best for test
Fine-tune
---------
Since LibriSpeech and GigaSpeech are both English dataset, we can initialize the whole
Zipformer model with the checkpoint downloaded in the previous step (otherwise we should consider
initializing the stateless decoder and joiner from scratch due to the mismatch of the output
vocabulary). The following command starts a fine-tuning experiment:
.. code-block:: bash
$ use_mux=0
$ do_finetune=1
$ ./zipformer/finetune.py \
--world-size 2 \
--num-epochs 20 \
--start-epoch 1 \
--exp-dir zipformer/exp_giga_finetune${do_finetune}_mux${use_mux} \
--use-fp16 1 \
--base-lr 0.0045 \
--bpe-model data/lang_bpe_500/bpe.model \
--do-finetune $do_finetune \
--use-mux $use_mux \
--master-port 13024 \
--finetune-ckpt icefall-asr-librispeech-zipformer-2023-05-15/exp/pretrained.pt \
--max-duration 1000
The following arguments are related to fine-tuning:
- ``--base-lr``
The learning rate used for fine-tuning. We suggest to set a **small** learning rate for fine-tuning,
otherwise the model may forget the initialization very quickly. A reasonable value should be around
1/10 of the original lr, i.e 0.0045.
- ``--do-finetune``
If True, do fine-tuning by initializing the model from a pre-trained checkpoint.
**Note that if you want to resume your fine-tuning experiment from certain epochs, you
need to set this to False.**
- ``--finetune-ckpt``
The path to the pre-trained checkpoint (used for initialization).
- ``--use-mux``
If True, mix the fine-tune data with the original training data by using `CutSet.mux <https://lhotse.readthedocs.io/en/latest/api.html#lhotse.supervision.SupervisionSet.mux>`_
This helps maintain the model's performance on the original domain if the original training
is available. **If you don't have the original training data, please set it to False.**
After fine-tuning, let's test the WERs. You can do this via the following command:
.. code-block:: bash
$ use_mux=0
$ do_finetune=1
$ ./zipformer/decode_gigaspeech.py \
--epoch 20 \
--avg 10 \
--exp-dir zipformer/exp_giga_finetune${do_finetune}_mux${use_mux} \
--use-averaged-model 1 \
--max-duration 1000 \
--decoding-method greedy_search
You should see numbers similar to the ones below:
.. code-block:: text
For dev, WER of different settings are:
greedy_search 13.47 best for dev
For test, WER of different settings are:
greedy_search 13.66 best for test
Compared to the original checkpoint, the fine-tuned model achieves much lower WERs
on the GigaSpeech test sets.

View File

@ -0,0 +1,15 @@
Fine-tune a pre-trained model
=============================
After pre-training on public available datasets, the ASR model is already capable of
performing general speech recognition with relatively high accuracy. However, the accuracy
could be still low on certain domains that are quite different from the original training
set. In this case, we can fine-tune the model with a small amount of additional labelled
data to improve the performance on new domains.
.. toctree::
:maxdepth: 2
:caption: Table of Contents
from_supervised/finetune_zipformer

View File

@ -1,11 +1,11 @@
VITS
VITS-LJSpeech
===============
This tutorial shows you how to train an VITS model
with the `LJSpeech <https://keithito.com/LJ-Speech-Dataset/>`_ dataset.
.. note::
TTS related recipes require packages in ``requirements-tts.txt``.
.. note::
@ -120,4 +120,4 @@ Download pretrained models
If you don't want to train from scratch, you can download the pretrained models
by visiting the following link:
- `<https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2023-11-29>`_
- `<https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2024-02-28>`_

View File

@ -1,11 +1,11 @@
VITS
VITS-VCTK
===============
This tutorial shows you how to train an VITS model
with the `VCTK <https://datashare.ed.ac.uk/handle/10283/3443>`_ dataset.
.. note::
TTS related recipes require packages in ``requirements-tts.txt``.
.. note::

View File

@ -17,3 +17,4 @@ We may add recipes for other tasks as well in the future.
Streaming-ASR/index
RNN-LM/index
TTS/index
Finetune/index

View File

@ -19,7 +19,7 @@
Usage:
#fine-tuning with deepspeed zero stage 1
torchrun --nproc-per-node 8 ./whisper/train.py \
torchrun --nproc_per_node 8 ./whisper/train.py \
--max-duration 200 \
--exp-dir whisper/exp_large_v2 \
--model-name large-v2 \
@ -28,7 +28,7 @@ torchrun --nproc-per-node 8 ./whisper/train.py \
--deepspeed_config ./whisper/ds_config_zero1.json
# fine-tuning with ddp
torchrun --nproc-per-node 8 ./whisper/train.py \
torchrun --nproc_per_node 8 ./whisper/train.py \
--max-duration 200 \
--exp-dir whisper/exp_medium \
--manifest-dir data/fbank_whisper \
@ -136,7 +136,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless7/exp",
default="whisper/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved

View File

@ -79,10 +79,10 @@ It will generate the following 3 files inside $repo/exp:
import argparse
import logging
from icefall import is_module_available
import torch
from onnx_pretrained import OnnxModel
import torch
from icefall import is_module_available
def get_parser():

View File

@ -70,9 +70,9 @@ import logging
from pathlib import Path
import torch
from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
from scaling_converter import convert_scaled_to_non_scaled
from tokenizer import Tokenizer
from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,

View File

@ -0,0 +1,9 @@
## Fluent Speech Commands recipe
This is a recipe for the Fluent Speech Commands dataset, a speech dataset which transcribes short utterances (such as "turn the lights on in the kitchen") into action frames (such as {"action": "activate", "object": "lights", "location": "kitchen"}). The training set contains 23,132 utterances, whereas the test set contains 3793 utterances.
Dataset Paper link: <https://paperswithcode.com/dataset/fluent-speech-commands>
cd icefall/egs/fluent_speech_commands/
Training: python transducer/train.py
Decoding: python transducer/decode.py

View File

@ -0,0 +1,136 @@
#!/usr/bin/env python3
"""
This script takes as input lang_dir and generates HLG from
- H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt
- L, the lexicon, built from lang_dir/L_disambig.pt
Caution: We use a lexicon that contains disambiguation symbols
- G, the LM, built from data/lm/G.fst.txt
The generated HLG is saved in $lang_dir/HLG.pt
"""
import argparse
import logging
from pathlib import Path
import k2
import torch
from icefall.lexicon import Lexicon
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
""",
)
return parser.parse_args()
def compile_HLG(lang_dir: str) -> k2.Fsa:
"""
Args:
lang_dir:
The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
Return:
An FSA representing HLG.
"""
lexicon = Lexicon(lang_dir)
max_token_id = max(lexicon.tokens)
logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
H = k2.ctc_topo(max_token_id)
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
logging.info("Loading G.fst.txt")
with open(lang_dir / "G.fst.txt") as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
first_token_disambig_id = lexicon.token_table["#0"]
first_word_disambig_id = lexicon.word_table["#0"]
L = k2.arc_sort(L)
G = k2.arc_sort(G)
logging.info("Intersecting L and G")
LG = k2.compose(L, G)
logging.info(f"LG shape: {LG.shape}")
logging.info("Connecting LG")
LG = k2.connect(LG)
logging.info(f"LG shape after k2.connect: {LG.shape}")
logging.info(type(LG.aux_labels))
logging.info("Determinizing LG")
LG = k2.determinize(LG)
logging.info(type(LG.aux_labels))
logging.info("Connecting LG after k2.determinize")
LG = k2.connect(LG)
logging.info("Removing disambiguation symbols on LG")
# LG.labels[LG.labels >= first_token_disambig_id] = 0
# see https://github.com/k2-fsa/k2/pull/1140
labels = LG.labels
labels[labels >= first_token_disambig_id] = 0
LG.labels = labels
assert isinstance(LG.aux_labels, k2.RaggedTensor)
LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
LG = k2.remove_epsilon(LG)
logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
LG = k2.connect(LG)
LG.aux_labels = LG.aux_labels.remove_values_eq(0)
logging.info("Arc sorting LG")
LG = k2.arc_sort(LG)
logging.info("Composing H and LG")
# CAUTION: The name of the inner_labels is fixed
# to `tokens`. If you want to change it, please
# also change other places in icefall that are using
# it.
HLG = k2.compose(H, LG, inner_labels="tokens")
logging.info("Connecting LG")
HLG = k2.connect(HLG)
logging.info("Arc sorting LG")
HLG = k2.arc_sort(HLG)
logging.info(f"HLG.shape: {HLG.shape}")
return HLG
def main():
args = get_args()
lang_dir = Path(args.lang_dir)
if (lang_dir / "HLG.pt").is_file():
logging.info(f"{lang_dir}/HLG.pt already exists - skipping")
return
logging.info(f"Processing {lang_dir}")
HLG = compile_HLG(lang_dir)
logging.info(f"Saving HLG.pt to {lang_dir}")
torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,97 @@
#!/usr/bin/env python3
"""
This file computes fbank features of the Fluent Speech Commands dataset.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import argparse
import logging
import os
from pathlib import Path
import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
# Torch's multithreaded behavior needs to be disabled or it wastes a
# lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_slu(manifest_dir, fbanks_dir):
src_dir = Path(manifest_dir)
output_dir = Path(fbanks_dir)
# This dataset is rather small, so we use only one job
num_jobs = min(1, os.cpu_count())
num_mel_bins = 23
dataset_parts = (
"train",
"valid",
"test",
)
prefix = "slu"
suffix = "jsonl.gz"
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
output_dir=src_dir,
prefix=prefix,
suffix=suffix,
)
assert manifests is not None
assert len(manifests) == len(dataset_parts), (
len(manifests),
len(dataset_parts),
list(manifests.keys()),
dataset_parts,
)
extractor = Fbank(FbankConfig(sampling_rate=16000, num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once.
for partition, m in manifests.items():
cuts_file = output_dir / f"{prefix}_cuts_{partition}.{suffix}"
if cuts_file.is_file():
logging.info(f"{partition} already exists - skipping.")
continue
logging.info(f"Processing {partition}")
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
)
if "train" in partition:
cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
# when an executor is specified, make more partitions
num_jobs=num_jobs if ex is None else 1, # use one job
executor=ex,
storage_type=LilcomChunkyWriter,
)
cut_set.to_file(cuts_file)
parser = argparse.ArgumentParser()
parser.add_argument("manifest_dir")
parser.add_argument("fbanks_dir")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
args = parser.parse_args()
logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_slu(args.manifest_dir, args.fbanks_dir)

View File

@ -0,0 +1,59 @@
import argparse
import pandas
from tqdm import tqdm
def generate_lexicon(corpus_dir, lm_dir):
data = pandas.read_csv(
str(corpus_dir) + "/data/train_data.csv", index_col=0, header=0
)
vocab_transcript = set()
vocab_frames = set()
transcripts = data["transcription"].tolist()
frames = list(
i
for i in zip(
data["action"].tolist(), data["object"].tolist(), data["location"].tolist()
)
)
for transcript in tqdm(transcripts):
for word in transcript.split():
vocab_transcript.add(word)
for frame in tqdm(frames):
for word in frame:
vocab_frames.add("_".join(word.split()))
with open(lm_dir + "/words_transcript.txt", "w") as lexicon_transcript_file:
lexicon_transcript_file.write("<UNK> 1" + "\n")
lexicon_transcript_file.write("<s> 2" + "\n")
lexicon_transcript_file.write("</s> 0" + "\n")
id = 3
for vocab in vocab_transcript:
lexicon_transcript_file.write(vocab + " " + str(id) + "\n")
id += 1
with open(lm_dir + "/words_frames.txt", "w") as lexicon_frames_file:
lexicon_frames_file.write("<UNK> 1" + "\n")
lexicon_frames_file.write("<s> 2" + "\n")
lexicon_frames_file.write("</s> 0" + "\n")
id = 3
for vocab in vocab_frames:
lexicon_frames_file.write(vocab + " " + str(id) + "\n")
id += 1
parser = argparse.ArgumentParser()
parser.add_argument("corpus_dir")
parser.add_argument("lm_dir")
def main():
args = parser.parse_args()
generate_lexicon(args.corpus_dir, args.lm_dir)
main()

View File

@ -0,0 +1,371 @@
#!/usr/bin/env python3
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
"""
This script takes as input a lexicon file "data/lang_phone/lexicon.txt"
consisting of words and tokens (i.e., phones) and does the following:
1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt
2. Generate tokens.txt, the token table mapping a token to a unique integer.
3. Generate words.txt, the word table mapping a word to a unique integer.
4. Generate L.pt, in k2 format. It can be loaded by
d = torch.load("L.pt")
lexicon = k2.Fsa.from_dict(d)
5. Generate L_disambig.pt, in k2 format.
"""
import argparse
import math
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Tuple
import k2
import torch
from icefall.lexicon import read_lexicon, write_lexicon
Lexicon = List[Tuple[str, List[str]]]
def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
"""Write a symbol to ID mapping to a file.
Note:
No need to implement `read_mapping` as it can be done
through :func:`k2.SymbolTable.from_file`.
Args:
filename:
Filename to save the mapping.
sym2id:
A dict mapping symbols to IDs.
Returns:
Return None.
"""
with open(filename, "w", encoding="utf-8") as f:
for sym, i in sym2id.items():
f.write(f"{sym} {i}\n")
def get_tokens(lexicon: Lexicon) -> List[str]:
"""Get tokens from a lexicon.
Args:
lexicon:
It is the return value of :func:`read_lexicon`.
Returns:
Return a list of unique tokens.
"""
ans = set()
for _, tokens in lexicon:
ans.update(tokens)
sorted_ans = sorted(list(ans))
return sorted_ans
def get_words(lexicon: Lexicon) -> List[str]:
"""Get words from a lexicon.
Args:
lexicon:
It is the return value of :func:`read_lexicon`.
Returns:
Return a list of unique words.
"""
ans = set()
for word, _ in lexicon:
ans.add(word)
sorted_ans = sorted(list(ans))
return sorted_ans
def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]:
"""It adds pseudo-token disambiguation symbols #1, #2 and so on
at the ends of tokens to ensure that all pronunciations are different,
and that none is a prefix of another.
See also add_lex_disambig.pl from kaldi.
Args:
lexicon:
It is returned by :func:`read_lexicon`.
Returns:
Return a tuple with two elements:
- The output lexicon with disambiguation symbols
- The ID of the max disambiguation symbol that appears
in the lexicon
"""
# (1) Work out the count of each token-sequence in the
# lexicon.
count = defaultdict(int)
for _, tokens in lexicon:
count[" ".join(tokens)] += 1
# (2) For each left sub-sequence of each token-sequence, note down
# that it exists (for identifying prefixes of longer strings).
issubseq = defaultdict(int)
for _, tokens in lexicon:
tokens = tokens.copy()
tokens.pop()
while tokens:
issubseq[" ".join(tokens)] = 1
tokens.pop()
# (3) For each entry in the lexicon:
# if the token sequence is unique and is not a
# prefix of another word, no disambig symbol.
# Else output #1, or #2, #3, ... if the same token-seq
# has already been assigned a disambig symbol.
ans = []
# We start with #1 since #0 has its own purpose
first_allowed_disambig = 1
max_disambig = first_allowed_disambig - 1
last_used_disambig_symbol_of = defaultdict(int)
for word, tokens in lexicon:
tokenseq = " ".join(tokens)
assert tokenseq != ""
if issubseq[tokenseq] == 0 and count[tokenseq] == 1:
ans.append((word, tokens))
continue
cur_disambig = last_used_disambig_symbol_of[tokenseq]
if cur_disambig == 0:
cur_disambig = first_allowed_disambig
else:
cur_disambig += 1
if cur_disambig > max_disambig:
max_disambig = cur_disambig
last_used_disambig_symbol_of[tokenseq] = cur_disambig
tokenseq += f" #{cur_disambig}"
ans.append((word, tokenseq.split()))
return ans, max_disambig
def generate_id_map(symbols: List[str]) -> Dict[str, int]:
"""Generate ID maps, i.e., map a symbol to a unique ID.
Args:
symbols:
A list of unique symbols.
Returns:
A dict containing the mapping between symbols and IDs.
"""
return {sym: i for i, sym in enumerate(symbols)}
def add_self_loops(
arcs: List[List[Any]], disambig_token: int, disambig_word: int
) -> List[List[Any]]:
"""Adds self-loops to states of an FST to propagate disambiguation symbols
through it. They are added on each state with non-epsilon output symbols
on at least one arc out of the state.
See also fstaddselfloops.pl from Kaldi. One difference is that
Kaldi uses OpenFst style FSTs and it has multiple final states.
This function uses k2 style FSTs and it does not need to add self-loops
to the final state.
The input label of a self-loop is `disambig_token`, while the output
label is `disambig_word`.
Args:
arcs:
A list-of-list. The sublist contains
`[src_state, dest_state, label, aux_label, score]`
disambig_token:
It is the token ID of the symbol `#0`.
disambig_word:
It is the word ID of the symbol `#0`.
Return:
Return new `arcs` containing self-loops.
"""
states_needs_self_loops = set()
for arc in arcs:
src, dst, ilabel, olabel, score = arc
if olabel != 0:
states_needs_self_loops.add(src)
ans = []
for s in states_needs_self_loops:
ans.append([s, s, disambig_token, disambig_word, 0])
return arcs + ans
def lexicon_to_fst(
lexicon: Lexicon,
token2id: Dict[str, int],
word2id: Dict[str, int],
sil_token: str = "!SIL",
sil_prob: float = 0.5,
need_self_loops: bool = False,
) -> k2.Fsa:
"""Convert a lexicon to an FST (in k2 format) with optional silence at
the beginning and end of each word.
Args:
lexicon:
The input lexicon. See also :func:`read_lexicon`
token2id:
A dict mapping tokens to IDs.
word2id:
A dict mapping words to IDs.
sil_token:
The silence token.
sil_prob:
The probability for adding a silence at the beginning and end
of the word.
need_self_loops:
If True, add self-loop to states with non-epsilon output symbols
on at least one arc out of the state. The input label for this
self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
Returns:
Return an instance of `k2.Fsa` representing the given lexicon.
"""
assert sil_prob > 0.0 and sil_prob < 1.0
# CAUTION: we use score, i.e, negative cost.
sil_score = math.log(sil_prob)
no_sil_score = math.log(1.0 - sil_prob)
start_state = 0
loop_state = 1 # words enter and leave from here
sil_state = 2 # words terminate here when followed by silence; this state
# has a silence transition to loop_state.
next_state = 3 # the next un-allocated state, will be incremented as we go.
arcs = []
# assert token2id["<eps>"] == 0
# assert word2id["<eps>"] == 0
eps = 0
sil_token = word2id[sil_token]
arcs.append([start_state, loop_state, eps, eps, no_sil_score])
arcs.append([start_state, sil_state, eps, eps, sil_score])
arcs.append([sil_state, loop_state, sil_token, eps, 0])
for word, tokens in lexicon:
assert len(tokens) > 0, f"{word} has no pronunciations"
cur_state = loop_state
word = word2id[word]
tokens = [word2id[i] for i in tokens]
for i in range(len(tokens) - 1):
w = word if i == 0 else eps
arcs.append([cur_state, next_state, tokens[i], w, 0])
cur_state = next_state
next_state += 1
# now for the last token of this word
# It has two out-going arcs, one to the loop state,
# the other one to the sil_state.
i = len(tokens) - 1
w = word if i == 0 else eps
arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score])
arcs.append([cur_state, sil_state, tokens[i], w, sil_score])
if need_self_loops:
disambig_token = word2id["#0"]
disambig_word = word2id["#0"]
arcs = add_self_loops(
arcs,
disambig_token=disambig_token,
disambig_word=disambig_word,
)
final_state = next_state
arcs.append([loop_state, final_state, -1, -1, 0])
arcs.append([final_state])
arcs = sorted(arcs, key=lambda arc: arc[0])
arcs = [[str(i) for i in arc] for arc in arcs]
arcs = [" ".join(arc) for arc in arcs]
arcs = "\n".join(arcs)
fsa = k2.Fsa.from_str(arcs, acceptor=False)
return fsa
parser = argparse.ArgumentParser()
parser.add_argument("lm_dir")
def main():
args = parser.parse_args()
out_dir = Path(args.lm_dir)
lexicon_filenames = [out_dir / "words_frames.txt", out_dir / "words_transcript.txt"]
names = ["frames", "transcript"]
sil_token = "!SIL"
sil_prob = 0.5
for name, lexicon_filename in zip(names, lexicon_filenames):
lexicon = read_lexicon(lexicon_filename)
tokens = get_words(lexicon)
words = get_words(lexicon)
new_lexicon = []
for lexicon_item in lexicon:
new_lexicon.append((lexicon_item[0], [lexicon_item[0]]))
lexicon = new_lexicon
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
for i in range(max_disambig + 1):
disambig = f"#{i}"
assert disambig not in tokens
tokens.append(f"#{i}")
tokens = ["<eps>"] + tokens
words = ["eps"] + words + ["#0", "!SIL"]
token2id = generate_id_map(tokens)
word2id = generate_id_map(words)
write_mapping(out_dir / ("tokens_" + name + ".txt"), token2id)
write_mapping(out_dir / ("words_" + name + ".txt"), word2id)
write_lexicon(out_dir / ("lexicon_disambig_" + name + ".txt"), lexicon_disambig)
L = lexicon_to_fst(
lexicon,
token2id=word2id,
word2id=word2id,
sil_token=sil_token,
sil_prob=sil_prob,
)
L_disambig = lexicon_to_fst(
lexicon_disambig,
token2id=word2id,
word2id=word2id,
sil_token=sil_token,
sil_prob=sil_prob,
need_self_loops=True,
)
torch.save(L.as_dict(), out_dir / ("L_" + name + ".pt"))
torch.save(L_disambig.as_dict(), out_dir / ("L_disambig_" + name + ".pt"))
if False:
# Just for debugging, will remove it
L.labels_sym = k2.SymbolTable.from_file(out_dir / "tokens.txt")
L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt")
L_disambig.labels_sym = L.labels_sym
L_disambig.aux_labels_sym = L.aux_labels_sym
L.draw(out_dir / "L.png", title="L")
L_disambig.draw(out_dir / "L_disambig.png", title="L_disambig")
main()

View File

@ -0,0 +1,103 @@
#!/usr/bin/env bash
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
stage=1
stop_stage=5
data_dir=path/to/fluent/speech/commands
target_root_dir=data/
lang_dir=${target_root_dir}/lang_phone
lm_dir=${target_root_dir}/lm
manifest_dir=${target_root_dir}/manifests
fbanks_dir=${target_root_dir}/fbanks
. shared/parse_options.sh || exit 1
mkdir -p $lang_dir
mkdir -p $lm_dir
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "data_dir: $data_dir"
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare slu manifest"
mkdir -p $manifest_dir
lhotse prepare slu $data_dir $manifest_dir
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Compute fbank for SLU"
mkdir -p $fbanks_dir
python ./local/compute_fbank_slu.py $manifest_dir $fbanks_dir
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Prepare lang"
# NOTE: "<UNK> SIL" is added for implementation convenience
# as the graph compiler code requires that there is a OOV word
# in the lexicon.
python ./local/generate_lexicon.py $data_dir $lm_dir
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Train LM"
# We use a unigram G
./shared/make_kn_lm.py \
-ngram-order 1 \
-text $lm_dir/words_transcript.txt \
-lm $lm_dir/G_transcript.arpa
./shared/make_kn_lm.py \
-ngram-order 1 \
-text $lm_dir/words_frames.txt \
-lm $lm_dir/G_frames.arpa
python ./local/prepare_lang.py $lm_dir
if [ ! -f $lm_dir/G_transcript.fst.txt ]; then
python -m kaldilm \
--read-symbol-table="$lm_dir/words_transcript.txt" \
$lm_dir/G_transcript.arpa > $lm_dir/G_transcript.fst.txt
fi
if [ ! -f $lm_dir/G_frames.fst.txt ]; then
python -m kaldilm \
--read-symbol-table="$lm_dir/words_frames.txt" \
$lm_dir/G_frames.arpa > $lm_dir/G_frames.fst.txt
fi
mkdir -p $lm_dir/frames
mkdir -p $lm_dir/transcript
chmod -R +777 .
for i in G_frames.arpa G_frames.fst.txt L_disambig_frames.pt L_frames.pt lexicon_disambig_frames.txt tokens_frames.txt words_frames.txt;
do
j=${i//"_frames"/}
mv "$lm_dir/$i" $lm_dir/frames/$j
done
for i in G_transcript.arpa G_transcript.fst.txt L_disambig_transcript.pt L_transcript.pt lexicon_disambig_transcript.txt tokens_transcript.txt words_transcript.txt;
do
j=${i//"_transcript"/}
mv "$lm_dir/$i" $lm_dir/transcript/$j
done
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Compile HLG"
./local/compile_hlg.py --lang-dir $lm_dir/frames
./local/compile_hlg.py --lang-dir $lm_dir/transcript
fi

View File

@ -0,0 +1 @@
../../../icefall/shared/

View File

@ -0,0 +1,71 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import torch
from transducer.model import Transducer
def greedy_search(
model: Transducer, encoder_out: torch.Tensor, id2word: dict
) -> List[str]:
"""
Args:
model:
An instance of `Transducer`.
encoder_out:
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
Returns:
Return the decoded result.
"""
assert encoder_out.ndim == 3
# support only batch_size == 1 for now
assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id
device = model.device
sos = torch.tensor([blank_id], device=device).reshape(1, 1)
decoder_out, (h, c) = model.decoder(sos)
T = encoder_out.size(1)
t = 0
hyp = []
max_u = 1000 # terminate after this number of steps
u = 0
while t < T and u < max_u:
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :]
# fmt: on
logits = model.joiner(current_encoder_out, decoder_out)
log_prob = logits.log_softmax(dim=-1)
# log_prob is (N, 1, 1)
# TODO: Use logits.argmax()
y = log_prob.argmax()
if y != blank_id:
hyp.append(y.item())
y = y.reshape(1, 1)
decoder_out, (h, c) = model.decoder(y, (h, c))
u += 1
else:
t += 1
# id2word = {1: "YES", 2: "NO"}
hyp = [id2word[i] for i in hyp]
return hyp

View File

@ -0,0 +1 @@
../../../librispeech/ASR/transducer_stateless/conformer.py

View File

@ -0,0 +1,346 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from pathlib import Path
from typing import List, Tuple
import torch
import torch.nn as nn
from transducer.beam_search import greedy_search
from transducer.conformer import Conformer
from transducer.decoder import Decoder
from transducer.joiner import Joiner
from transducer.model import Transducer
from transducer.slu_datamodule import SluDataModule
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
write_error_stats,
)
def get_id2word(params):
id2word = {}
# 0 is blank
id = 1
try:
with open(Path(params.lang_dir) / "lexicon_disambig.txt") as lexicon_file:
for line in lexicon_file:
if len(line.strip()) > 0:
id2word[id] = line.split()[0]
id += 1
except:
pass
return id2word
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=6,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=1,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="transducer/exp",
help="Directory from which to load the checkpoints",
)
parser.add_argument("--lang-dir", type=str, default="data/lm/frames")
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"feature_dim": 23,
"lang_dir": Path("data/lm/frames"),
# encoder/decoder params
"vocab_size": 3, # blank, yes, no
"blank_id": 0,
"embedding_dim": 32,
"hidden_dim": 16,
"num_decoder_layers": 4,
}
)
vocab_size = 1
with open(params.lang_dir / "lexicon_disambig.txt") as lexicon_file:
for line in lexicon_file:
if (
len(line.strip()) > 0
): # and '<UNK>' not in line and '<s>' not in line and '</s>' not in line:
vocab_size += 1
params.vocab_size = vocab_size
return params
def decode_one_batch(
params: AttributeDict, model: nn.Module, batch: dict, id2word: dict
) -> List[List[int]]:
"""Decode one batch and return the result in a list-of-list.
Each sub list contains the word IDs for an utterance in the batch.
Args:
params:
It's the return value of :func:`get_params`.
- params.method is "1best", it uses 1best decoding.
- params.method is "nbest", it uses nbest decoding.
model:
The neural model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
(https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py)
Returns:
Return the decoding result. `len(ans)` == batch size.
"""
device = model.device
feature = batch["inputs"]
feature = feature.to(device)
# at entry, feature is (N, T, C)
feature_lens = batch["supervisions"]["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
hyps = []
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
hyp = greedy_search(model=model, encoder_out=encoder_out_i, id2word=id2word)
hyps.append(hyp)
# hyps = [[word_table[i] for i in ids] for ids in hyps]
return hyps
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
) -> List[Tuple[List[int], List[int]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
Returns:
Return a tuple contains two elements (ref_text, hyp_text):
The first is the reference transcript, and the second is the
predicted result.
"""
results = []
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
id2word = get_id2word(params)
results = []
for batch_idx, batch in enumerate(dl):
texts = [
" ".join(a.supervisions[0].custom["frames"])
for a in batch["supervisions"]["cut"]
]
texts = [
"<s> " + a.replace("change language", "change_language") + " </s>"
for a in texts
]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps = decode_one_batch(
params=params, model=model, batch=batch, id2word=id2word
)
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
results.extend(this_batch)
num_cuts += len(batch["supervisions"]["text"])
if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
def save_results(
exp_dir: Path,
test_set_name: str,
results: List[Tuple[List[int], List[int]]],
) -> None:
"""Save results to `exp_dir`.
Args:
exp_dir:
The output directory. This function create the following files inside
this directory:
- recogs-{test_set_name}.text
It contains the reference and hypothesis results, like below::
ref=['NO', 'NO', 'NO', 'YES', 'NO', 'NO', 'NO', 'YES']
hyp=['NO', 'NO', 'NO', 'YES', 'NO', 'NO', 'NO', 'YES']
ref=['NO', 'NO', 'YES', 'NO', 'YES', 'NO', 'NO', 'YES']
hyp=['NO', 'NO', 'YES', 'NO', 'YES', 'NO', 'NO', 'YES']
- errs-{test_set_name}.txt
It contains the detailed WER.
test_set_name:
The name of the test set, which will be part of the result filename.
results:
A list of tuples, each of which contains (ref_words, hyp_words).
Returns:
Return None.
"""
recog_path = exp_dir / f"recogs-{test_set_name}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = exp_dir / f"errs-{test_set_name}.txt"
with open(errs_filename, "w") as f:
write_error_stats(f, f"{test_set_name}", results)
logging.info("Wrote detailed error stats to {}".format(errs_filename))
def get_transducer_model(params: AttributeDict):
# encoder = Tdnn(
# num_features=params.feature_dim,
# output_dim=params.hidden_dim,
# )
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.hidden_dim,
)
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.embedding_dim,
blank_id=params.blank_id,
num_layers=params.num_decoder_layers,
hidden_dim=params.hidden_dim,
embedding_dropout=0.4,
rnn_dropout=0.4,
)
joiner = Joiner(input_dim=params.hidden_dim, output_dim=params.vocab_size)
transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner)
return transducer
@torch.no_grad()
def main():
parser = get_parser()
SluDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
params["env_info"] = get_env_info()
setup_logger(f"{params.exp_dir}/log/log-decode")
logging.info("Decoding started")
logging.info(params)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
model = get_transducer_model(params)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames))
model.to(device)
model.eval()
model.device = device
# we need cut ids to display recognition results.
args.return_cuts = True
slu = SluDataModule(args)
test_dl = slu.test_dataloaders()
results = decode_dataset(
dl=test_dl,
params=params,
model=model,
)
test_set_name = str(args.feature_dir).split("/")[-2]
save_results(exp_dir=params.exp_dir, test_set_name=test_set_name, results=results)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../../../yesno/ASR/transducer/decoder.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/transducer_stateless/encoder_interface.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/transducer/joiner.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/transducer/model.py

View File

@ -0,0 +1,289 @@
# Copyright 2021 Piotr Żelasko
# 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import List
from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy
from lhotse.dataset import (
CutConcatenate,
DynamicBucketingSampler,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SimpleCutSampler,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from torch.utils.data import DataLoader
from icefall.dataset.datamodule import DataModule
from icefall.utils import str2bool
class SluDataModule(DataModule):
"""
DataModule for k2 ASR experiments.
It assumes there is always one train dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- augmentation,
- on-the-fly feature extraction
"""
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
super().add_arguments(parser)
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--feature-dir",
type=Path,
default=Path("data/fbanks"),
help="Path to directory with train/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=30.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=False,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=10,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
default=False,
help="When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding.",
)
group.add_argument(
"--duration-factor",
type=float,
default=1.0,
help="Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch.",
)
group.add_argument(
"--gap",
type=float,
default=1.0,
help="The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when "
"noise augmentation is used.",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
def train_dataloaders(self) -> DataLoader:
logging.info("About to get train cuts")
cuts_train = self.train_cuts()
logging.info("About to create train dataset")
transforms = []
if self.args.concatenate_cuts:
logging.info(
f"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}."
)
# Cut concatenation should be the first transform in the list,
# so that if we e.g. mix noise in, it will fill the gaps between
# different utterances.
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
# NOTE: the PerturbSpeed transform should be added only if we
# remove it from data prep stage.
# Add on-the-fly speed perturbation; since originally it would
# have increased epoch size by 3, we will apply prob 2/3 and use
# 3x more epochs.
# Speed perturbation probably should come first before
# concatenation, but in principle the transforms order doesn't have
# to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
FbankConfig(sampling_rate=8000, num_mel_bins=23)
),
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
drop_last=True,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=True,
)
return train_dl
def valid_dataloaders(self) -> DataLoader:
logging.info("About to get valid cuts")
cuts_valid = self.valid_cuts()
logging.debug("About to create valid dataset")
valid = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=23)))
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.debug("About to create valid dataloader")
valid_dl = DataLoader(
valid,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
persistent_workers=True,
)
return valid_dl
def test_dataloaders(self) -> DataLoader:
logging.info("About to get test cuts")
cuts_test = self.test_cuts()
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=23)))
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
cuts_test,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
persistent_workers=True,
)
return test_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
cuts_train = load_manifest_lazy(
self.args.feature_dir / "slu_cuts_train.jsonl.gz"
)
return cuts_train
@lru_cache()
def valid_cuts(self) -> List[CutSet]:
logging.info("About to get valid cuts")
cuts_valid = load_manifest_lazy(
self.args.feature_dir / "slu_cuts_valid.jsonl.gz"
)
return cuts_valid
@lru_cache()
def test_cuts(self) -> List[CutSet]:
logging.info("About to get test cuts")
cuts_test = load_manifest_lazy(self.args.feature_dir / "slu_cuts_test.jsonl.gz")
return cuts_test

View File

@ -0,0 +1 @@
../../../librispeech/ASR/transducer_stateless/subsampling.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/transducer/test_conformer.py

View File

@ -0,0 +1 @@
../../../yesno/ASR/transducer/test_decoder.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/transducer/test_joiner.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/transducer/test_transducer.py

View File

@ -0,0 +1,625 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from pathlib import Path
from shutil import copyfile
from typing import List, Optional, Tuple
import k2
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from lhotse.utils import fix_random_seed
from slu_datamodule import SluDataModule
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from transducer.conformer import Conformer
# from torch.utils.tensorboard import SummaryWriter
from transducer.decoder import Decoder
from transducer.joiner import Joiner
from transducer.model import Transducer
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
def get_word2id(params):
word2id = {}
# 0 is blank
id = 1
with open(Path(params.lang_dir) / "lexicon_disambig.txt") as lexicon_file:
for line in lexicon_file:
if len(line.strip()) > 0:
word2id[line.split()[0]] = id
id += 1
return word2id
def get_labels(texts: List[str], word2id) -> k2.RaggedTensor:
"""
Args:
texts:
A list of transcripts.
Returns:
Return a ragged tensor containing the corresponding word ID.
"""
# blank is 0
word_ids = []
for t in texts:
words = t.split()
ids = [word2id[w] for w in words]
word_ids.append(ids)
return k2.RaggedTensor(word_ids)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--world-size",
type=int,
default=1,
help="Number of GPUs for DDP training.",
)
parser.add_argument(
"--master-port",
type=int,
default=12354,
help="Master port to use for DDP training.",
)
parser.add_argument(
"--tensorboard",
type=str2bool,
default=True,
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--num-epochs",
type=int,
default=7,
help="Number of epochs to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=0,
help="""Resume training from from this epoch.
If it is positive, it will load checkpoint from
tdnn/exp/epoch-{start_epoch-1}.pt
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="transducer/exp",
help="Directory to save results",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
parser.add_argument("--lang-dir", type=str, default="data/lm/frames")
return parser
def get_params() -> AttributeDict:
"""Return a dict containing training parameters.
All training related parameters that are not passed from the commandline
is saved in the variable `params`.
Commandline options are merged into `params` after they are parsed, so
you can also access them via `params`.
Explanation of options saved in `params`:
- lr: It specifies the initial learning rate
- feature_dim: The model input dim. It has to match the one used
in computing features.
- weight_decay: The weight_decay for the optimizer.
- subsampling_factor: The subsampling factor for the model.
- start_epoch: If it is not zero, load checkpoint `start_epoch-1`
and continue training from that checkpoint.
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
- best_valid_loss: Best validation loss so far. It is used to select
the model that has the lowest validation loss. It is
updated during the training.
- best_train_epoch: It is the epoch that has the best training loss.
- best_valid_epoch: It is the epoch that has the best validation loss.
- batch_idx_train: Used to writing statistics to tensorboard. It
contains number of batches trained so far across
epochs.
- log_interval: Print training loss if batch_idx % log_interval` is 0
- valid_interval: Run validation if batch_idx % valid_interval` is 0
- reset_interval: Reset statistics if batch_idx % reset_interval is 0
"""
params = AttributeDict(
{
"lr": 1e-4,
"feature_dim": 23,
"weight_decay": 1e-6,
"start_epoch": 0,
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 100,
"reset_interval": 20,
"valid_interval": 3000,
"exp_dir": Path("transducer/exp"),
"lang_dir": Path("data/lm/frames"),
# encoder/decoder params
"vocab_size": 3, # blank, yes, no
"blank_id": 0,
"embedding_dim": 32,
"hidden_dim": 16,
"num_decoder_layers": 4,
}
)
vocab_size = 1
with open(Path(params.lang_dir) / "lexicon_disambig.txt") as lexicon_file:
for line in lexicon_file:
if (
len(line.strip()) > 0
): # and '<UNK>' not in line and '<s>' not in line and '</s>' not in line:
vocab_size += 1
params.vocab_size = vocab_size
return params
def load_checkpoint_if_available(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> None:
"""Load checkpoint from file.
If params.start_epoch is positive, it will load the checkpoint from
`params.start_epoch - 1`. Otherwise, this function does nothing.
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
and `best_valid_loss` in `params`.
Args:
params:
The return value of :func:`get_params`.
model:
The training model.
optimizer:
The optimizer that we are using.
scheduler:
The learning rate scheduler we are using.
Returns:
Return None.
"""
if params.start_epoch <= 0:
return
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
saved_params = load_checkpoint(
filename,
model=model,
optimizer=optimizer,
scheduler=scheduler,
)
keys = [
"best_train_epoch",
"best_valid_epoch",
"batch_idx_train",
"best_train_loss",
"best_valid_loss",
]
for k in keys:
params[k] = saved_params[k]
return saved_params
def save_checkpoint(
params: AttributeDict,
model: nn.Module,
optimizer: torch.optim.Optimizer,
scheduler: torch.optim.lr_scheduler._LRScheduler,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
Args:
params:
It is returned by :func:`get_params`.
model:
The training model.
"""
if rank != 0:
return
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
save_checkpoint_impl(
filename=filename,
model=model,
params=params,
optimizer=optimizer,
scheduler=scheduler,
rank=rank,
)
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
def compute_loss(
params: AttributeDict, model: nn.Module, batch: dict, is_training: bool, word2ids
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute RNN-T loss given the model and its inputs.
Args:
params:
Parameters for training. See :func:`get_params`.
model:
The model for training. It is an instance of Tdnn in our case.
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
is_training:
True for training. False for validation. When it is True, this
function enables autograd during computation; when it is False, it
disables autograd.
"""
device = model.device
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
feature = feature.to(device)
feature_lens = batch["supervisions"]["num_frames"].to(device)
texts = [
" ".join(a.supervisions[0].custom["frames"])
for a in batch["supervisions"]["cut"]
]
texts = [
"<s> " + a.replace("change language", "change_language") + " </s>"
for a in texts
]
labels = get_labels(texts, word2ids).to(device)
with torch.set_grad_enabled(is_training):
loss = model(x=feature, x_lens=feature_lens, y=labels)
assert loss.requires_grad == is_training
info = MetricsTracker()
info["frames"] = feature.size(0)
info["loss"] = loss.detach().cpu().item()
return loss, info
def compute_validation_loss(
params: AttributeDict,
model: nn.Module,
valid_dl: torch.utils.data.DataLoader,
word2ids,
world_size: int = 1,
) -> MetricsTracker:
"""Run the validation process. The validation loss
is saved in `params.valid_loss`.
"""
model.eval()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(valid_dl):
loss, loss_info = compute_loss(
params=params,
model=model,
batch=batch,
is_training=False,
word2ids=word2ids,
)
assert loss.requires_grad is False
tot_loss = tot_loss + loss_info
if world_size > 1:
tot_loss.reduce(loss.device)
loss_value = tot_loss["loss"] / tot_loss["frames"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
return tot_loss
def train_one_epoch(
params: AttributeDict,
model: nn.Module,
optimizer: torch.optim.Optimizer,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
word2ids,
tb_writer: None,
world_size: int = 1,
) -> None:
"""Train the model for one epoch.
The training loss from the mean of all frames is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model for training.
optimizer:
The optimizer we are using.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
Number of nodes in DDP training. If it is 1, DDP is disabled.
"""
model.train()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
loss, loss_info = compute_loss(
params=params, model=model, batch=batch, is_training=True, word2ids=word2ids
)
# summary stats.
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
optimizer.zero_grad()
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
if batch_idx % params.log_interval == 0:
logging.info(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}"
)
if batch_idx % params.log_interval == 0:
if tb_writer is not None:
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
valid_info = compute_validation_loss(
params=params,
model=model,
valid_dl=valid_dl,
world_size=world_size,
word2ids=word2ids,
)
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}")
if tb_writer is not None:
valid_info.write_summary(
tb_writer,
"train/valid_",
params.batch_idx_train,
)
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
def get_transducer_model(params: AttributeDict):
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.hidden_dim,
)
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.embedding_dim,
blank_id=params.blank_id,
num_layers=params.num_decoder_layers,
hidden_dim=params.hidden_dim,
embedding_dropout=0.4,
rnn_dropout=0.4,
)
joiner = Joiner(input_dim=params.hidden_dim, output_dim=params.vocab_size)
transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner)
return transducer
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
params["env_info"] = get_env_info()
word2ids = get_word2id(params)
fix_random_seed(params.seed)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
logging.info(params)
# if args.tensorboard and rank == 0:
# tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
# else:
# tb_writer = None
tb_writer = None
if torch.cuda.is_available():
device = torch.device("cuda", rank)
else:
device = torch.device("cpu")
logging.info(f"device: {device}")
model = get_transducer_model(params)
checkpoints = load_checkpoint_if_available(params=params, model=model)
model.to(device)
if world_size > 1:
model = DDP(model, device_ids=[rank])
model.device = device
optimizer = optim.Adam(
model.parameters(),
lr=params.lr,
weight_decay=params.weight_decay,
)
if checkpoints:
optimizer.load_state_dict(checkpoints["optimizer"])
slu = SluDataModule(args)
train_dl = slu.train_dataloaders()
# There are only 60 waves: 30 files are used for training
# and the remaining 30 files are used for testing.
# We use test data as validation.
valid_dl = slu.test_dataloaders()
for epoch in range(params.start_epoch, params.num_epochs):
fix_random_seed(params.seed + epoch)
train_dl.sampler.set_epoch(epoch)
if tb_writer is not None:
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
params.cur_epoch = epoch
train_one_epoch(
params=params,
model=model,
optimizer=optimizer,
train_dl=train_dl,
valid_dl=valid_dl,
tb_writer=tb_writer,
world_size=world_size,
word2ids=word2ids,
)
save_checkpoint(
params=params,
model=model,
optimizer=optimizer,
scheduler=None,
rank=rank,
)
logging.info("Done!")
if world_size > 1:
torch.distributed.barrier()
cleanup_dist()
def main():
parser = get_parser()
SluDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/transducer_stateless/transformer.py

View File

@ -30,15 +30,15 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_gigaspeech_dev_test():
def compute_fbank_gigaspeech():
in_out_dir = Path("data/fbank")
# number of workers in dataloader
num_workers = 20
# number of seconds in a batch
batch_duration = 600
batch_duration = 1000
subsets = ("DEV", "TEST")
subsets = ("L", "M", "S", "XS", "DEV", "TEST")
device = torch.device("cpu")
if torch.cuda.is_available():
@ -48,12 +48,12 @@ def compute_fbank_gigaspeech_dev_test():
logging.info(f"device: {device}")
for partition in subsets:
cuts_path = in_out_dir / f"cuts_{partition}.jsonl.gz"
cuts_path = in_out_dir / f"gigaspeech_cuts_{partition}.jsonl.gz"
if cuts_path.is_file():
logging.info(f"{cuts_path} exists - skipping")
continue
raw_cuts_path = in_out_dir / f"cuts_{partition}_raw.jsonl.gz"
raw_cuts_path = in_out_dir / f"gigaspeech_cuts_{partition}_raw.jsonl.gz"
logging.info(f"Loading {raw_cuts_path}")
cut_set = CutSet.from_file(raw_cuts_path)
@ -62,7 +62,7 @@ def compute_fbank_gigaspeech_dev_test():
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
storage_path=f"{in_out_dir}/feats_{partition}",
storage_path=f"{in_out_dir}/gigaspeech_feats_{partition}",
num_workers=num_workers,
batch_duration=batch_duration,
overwrite=True,
@ -80,7 +80,7 @@ def main():
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_gigaspeech_dev_test()
compute_fbank_gigaspeech()
if __name__ == "__main__":

View File

@ -76,7 +76,7 @@ def get_parser():
def compute_fbank_gigaspeech_splits(args):
num_splits = args.num_splits
output_dir = "data/fbank/XL_split"
output_dir = f"data/fbank/XL_split"
output_dir = Path(output_dir)
assert output_dir.exists(), f"{output_dir} does not exist!"
@ -96,15 +96,15 @@ def compute_fbank_gigaspeech_splits(args):
logging.info(f"device: {device}")
for i in range(start, stop):
idx = f"{i + 1}".zfill(num_digits)
idx = f"{i}".zfill(num_digits)
logging.info(f"Processing {idx}/{num_splits}")
cuts_path = output_dir / f"cuts_XL.{idx}.jsonl.gz"
cuts_path = output_dir / f"gigaspeech_cuts_XL.{idx}.jsonl.gz"
if cuts_path.is_file():
logging.info(f"{cuts_path} exists - skipping")
continue
raw_cuts_path = output_dir / f"cuts_XL_raw.{idx}.jsonl.gz"
raw_cuts_path = output_dir / f"gigaspeech_cuts_XL_raw.{idx}.jsonl.gz"
logging.info(f"Loading {raw_cuts_path}")
cut_set = CutSet.from_file(raw_cuts_path)
@ -113,7 +113,7 @@ def compute_fbank_gigaspeech_splits(args):
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
storage_path=f"{output_dir}/feats_XL_{idx}",
storage_path=f"{output_dir}/gigaspeech_feats_{idx}",
num_workers=args.num_workers,
batch_duration=args.batch_duration,
overwrite=True,

View File

@ -16,6 +16,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import re
from pathlib import Path
@ -23,10 +24,24 @@ from pathlib import Path
from lhotse import CutSet, SupervisionSegment
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import str2bool
# Similar text filtering and normalization procedure as in:
# https://github.com/SpeechColab/GigaSpeech/blob/main/toolkits/kaldi/gigaspeech_data_prep.sh
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--perturb-speed",
type=str2bool,
default=False,
help="Whether to use speed perturbation.",
)
return parser.parse_args()
def normalize_text(
utt: str,
punct_pattern=re.compile(r"<(COMMA|PERIOD|QUESTIONMARK|EXCLAMATIONPOINT)>"),
@ -42,7 +57,7 @@ def has_no_oov(
return oov_pattern.search(sup.text) is None
def preprocess_giga_speech():
def preprocess_giga_speech(args):
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
output_dir.mkdir(exist_ok=True)
@ -51,6 +66,10 @@ def preprocess_giga_speech():
"DEV",
"TEST",
"XL",
"L",
"M",
"S",
"XS",
)
logging.info("Loading manifest (may take 4 minutes)")
@ -71,7 +90,7 @@ def preprocess_giga_speech():
for partition, m in manifests.items():
logging.info(f"Processing {partition}")
raw_cuts_path = output_dir / f"cuts_{partition}_raw.jsonl.gz"
raw_cuts_path = output_dir / f"gigaspeech_cuts_{partition}_raw.jsonl.gz"
if raw_cuts_path.is_file():
logging.info(f"{partition} already exists - skipping")
continue
@ -94,11 +113,14 @@ def preprocess_giga_speech():
# Run data augmentation that needs to be done in the
# time domain.
if partition not in ["DEV", "TEST"]:
logging.info(
f"Speed perturb for {partition} with factors 0.9 and 1.1 "
"(Perturbing may take 8 minutes and saving may take 20 minutes)"
)
cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
if args.perturb_speed:
logging.info(
f"Speed perturb for {partition} with factors 0.9 and 1.1 "
"(Perturbing may take 8 minutes and saving may take 20 minutes)"
)
cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
logging.info(f"Saving to {raw_cuts_path}")
cut_set.to_file(raw_cuts_path)
@ -107,7 +129,8 @@ def main():
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
preprocess_giga_speech()
args = get_args()
preprocess_giga_speech(args)
if __name__ == "__main__":

View File

@ -99,7 +99,14 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
exit 1;
fi
# Download XL, DEV and TEST sets by default.
lhotse download gigaspeech --subset auto --host tsinghua \
lhotse download gigaspeech --subset XL \
--subset L \
--subset M \
--subset S \
--subset XS \
--subset DEV \
--subset TEST \
--host tsinghua \
$dl_dir/password $dl_dir/GigaSpeech
fi
@ -118,7 +125,14 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
# We assume that you have downloaded the GigaSpeech corpus
# to $dl_dir/GigaSpeech
mkdir -p data/manifests
lhotse prepare gigaspeech --subset auto -j $nj \
lhotse prepare gigaspeech --subset XL \
--subset L \
--subset M \
--subset S \
--subset XS \
--subset DEV \
--subset TEST \
-j $nj \
$dl_dir/GigaSpeech data/manifests
fi
@ -139,8 +153,8 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Compute features for DEV and TEST subsets of GigaSpeech (may take 2 minutes)"
python3 ./local/compute_fbank_gigaspeech_dev_test.py
log "Stage 4: Compute features for L, M, S, XS, DEV and TEST subsets of GigaSpeech."
python3 ./local/compute_fbank_gigaspeech.py
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
@ -176,18 +190,9 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
fi
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
log "Stage 9: Prepare phone based lang"
log "Stage 9: Prepare transcript_words.txt and words.txt"
lang_dir=data/lang_phone
mkdir -p $lang_dir
(echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
cat - $dl_dir/lm/lexicon.txt |
sort | uniq > $lang_dir/lexicon.txt
if [ ! -f $lang_dir/L_disambig.pt ]; then
./local/prepare_lang.py --lang-dir $lang_dir
fi
if [ ! -f $lang_dir/transcript_words.txt ]; then
gunzip -c "data/manifests/gigaspeech_supervisions_XL.jsonl.gz" \
| jq '.text' \
@ -238,7 +243,21 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
fi
if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
log "Stage 10: Prepare BPE based lang"
log "Stage 10: Prepare phone based lang"
lang_dir=data/lang_phone
mkdir -p $lang_dir
(echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
cat - $dl_dir/lm/lexicon.txt |
sort | uniq > $lang_dir/lexicon.txt
if [ ! -f $lang_dir/L_disambig.pt ]; then
./local/prepare_lang.py --lang-dir $lang_dir
fi
fi
if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
log "Stage 11: Prepare BPE based lang"
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size}
@ -260,8 +279,8 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
done
fi
if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
log "Stage 11: Prepare bigram P"
if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
log "Stage 12: Prepare bigram P"
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size}
@ -291,8 +310,8 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
done
fi
if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
log "Stage 12: Prepare G"
if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then
log "Stage 13: Prepare G"
# We assume you have installed kaldilm, if not, please install
# it using: pip install kaldilm
@ -317,8 +336,8 @@ if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
fi
fi
if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then
log "Stage 13: Compile HLG"
if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
log "Stage 14: Compile HLG"
./local/compile_hlg.py --lang-dir data/lang_phone
for vocab_size in ${vocab_sizes[@]}; do

View File

@ -76,6 +76,7 @@ from beam_search import (
)
from gigaspeech_scoring import asr_text_post_processing
from train import get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,

View File

@ -105,7 +105,7 @@ class GigaSpeechAsrDataModule:
group.add_argument(
"--num-buckets",
type=int,
default=30,
default=100,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
@ -368,6 +368,8 @@ class GigaSpeechAsrDataModule:
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
num_buckets=self.args.num_buckets,
buffer_size=self.args.num_buckets * 2000,
shuffle=False,
)
logging.info("About to create dev dataloader")
@ -417,6 +419,7 @@ class GigaSpeechAsrDataModule:
logging.info(
f"Loading GigaSpeech {len(sorted_filenames)} splits in lazy mode"
)
cuts_train = lhotse.combine(
lhotse.load_manifest_lazy(p) for p in sorted_filenames
)

View File

@ -88,7 +88,7 @@ import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import GigaSpeechAsrDataModule
from train import add_model_arguments, get_params, get_model
from train import add_model_arguments, get_model, get_params
from icefall.checkpoint import (
average_checkpoints,

View File

@ -51,7 +51,7 @@ from streaming_beam_search import (
)
from torch import Tensor, nn
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_model
from train import add_model_arguments, get_model, get_params
from icefall.checkpoint import (
average_checkpoints,

View File

@ -416,6 +416,17 @@ def get_parser():
help="Accumulate stats on activations, print them and exit.",
)
parser.add_argument(
"--scan-for-oom-batches",
type=str2bool,
default=False,
help="""
Whether to scan for oom batches before training, this is helpful for
finding the suitable max_duration, you only need to run it once.
Caution: a little time consuming.
""",
)
parser.add_argument(
"--inf-check",
type=str2bool,
@ -1171,9 +1182,16 @@ def run(rank, world_size, args):
if params.inf_check:
register_inf_check_hooks(model)
def remove_short_utt(c: Cut):
# In ./zipformer.py, the conv module uses the following expression
# for subsampling
T = ((c.num_frames - 7) // 2 + 1) // 2
return T > 0
gigaspeech = GigaSpeechAsrDataModule(args)
train_cuts = gigaspeech.train_cuts()
train_cuts = train_cuts.filter(remove_short_utt)
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
# We only load the sampler's state dict when it loads a checkpoint
@ -1187,9 +1205,10 @@ def run(rank, world_size, args):
)
valid_cuts = gigaspeech.dev_cuts()
valid_cuts = valid_cuts.filter(remove_short_utt)
valid_dl = gigaspeech.valid_dataloaders(valid_cuts)
if not params.print_diagnostics:
if not params.print_diagnostics and params.scan_for_oom_batches:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,

View File

@ -0,0 +1,49 @@
# Results
## zipformer transducer model
This is a tiny general ASR model, which has around 3.3M parameters, see this PR https://github.com/k2-fsa/icefall/pull/1428 for how to train it and other details.
The modeling units are 500 BPEs trained on gigaspeech transcripts.
The positive test sets are from https://github.com/pkufool/open-commands and the negative test set is test set of gigaspeech (has 40 hours audios).
We put the whole pipeline in `run.sh` containing training, decoding and finetuning commands.
The models have been upload to [github](https://github.com/pkufool/keyword-spotting-models/releases/download/v0.11/icefall-kws-zipformer-gigaspeech-20240219.tar.gz).
Here is the results of a small test set which has 20 commands, we list the results of every commands, for
each metric there are two columns, one for the original model trained on gigaspeech XL subset, the other
for the finetune model finetuned on commands dataset.
Commands | FN in positive set |FN in positive set | Recall | Recall | FP in negative set | FP in negative set| False alarm (time / hour) 40 hours | False alarm (time / hour) 40 hours |
-- | -- | -- | -- | --| -- | -- | -- | --
  | original | finetune | original | finetune | original | finetune | original | finetune
All | 43/307 | 4/307 | 86% | 98.7% | 1 | 24 | 0.025 | 0.6
Lights on | 6/17 | 0/17 | 64.7% | 100% | 1 | 9 | 0.025 | 0.225
Heat up | 5/14 | 1/14 | 64.3% | 92.9% | 0 | 1 | 0 | 0.025
Volume down | 4/18 | 0/18 | 77.8% | 100% | 0 | 2 | 0 | 0.05
Volume max | 4/17 | 0/17 | 76.5% | 100% | 0 | 0 | 0 | 0
Volume mute | 4/16 | 0/16 | 75.0% | 100% | 0 | 0 | 0 | 0
Too quiet | 3/17 | 0/17 | 82.4% | 100% | 0 | 4 | 0 | 0.1
Lights off | 3/17 | 0/17 | 82.4% | 100% | 0 | 2 | 0 | 0.05
Play music | 2/14 | 0/14 | 85.7% | 100% | 0 | 0 | 0 | 0
Bring newspaper | 2/13 | 1/13 | 84.6% | 92.3% | 0 | 0 | 0 | 0
Heat down | 2/16 | 2/16 | 87.5% | 87.5% | 0 | 1 | 0 | 0.025
Volume up | 2/18 | 0/18 | 88.9% | 100% | 0 | 1 | 0 | 0.025
Too loud | 1/13 | 0/13 | 92.3% | 100% | 0 | 0 | 0 | 0
Resume music | 1/14 | 0/14 | 92.9% | 100% | 0 | 0 | 0 | 0
Bring shoes | 1/15 | 0/15 | 93.3% | 100% | 0 | 0 | 0 | 0
Switch language | 1/15 | 0/15 | 93.3% | 100% | 0 | 0 | 0 | 0
Pause music | 1/15 | 0/15 | 93.3% | 100% | 0 | 0 | 0 | 0
Bring socks | 1/12 | 0/12 | 91.7% | 100% | 0 | 0 | 0 | 0
Stop music | 0/15 | 0/15 | 100% | 100% | 0 | 0 | 0 | 0
Turn it up | 0/15 | 0/15 | 100% | 100% | 0 | 3 | 0 | 0.075
Turn it down | 0/16 | 0/16 | 100% | 100% | 0 | 1 | 0 | 0.025
This is the result of large test set, it has more than 200 commands, too many to list the details of each commands, so only an overall result here.
Commands | FN in positive set | FN in positive set | Recall | Recall | FP in negative set | FP in negative set | False alarm (time / hour)23 hours | False alarm (time / hour)23 hours
-- | -- | -- | -- | -- | -- | -- | -- | --
  | original | finetune | original | finetune | original | finetune | original | finetune
All | 622/3994 | 79/ 3994 | 83.6% | 97.9% | 18/19930 | 52/19930 | 0.45 | 1.3

85
egs/gigaspeech/KWS/prepare.sh Executable file
View File

@ -0,0 +1,85 @@
#!/usr/bin/env bash
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
nj=15
stage=0
stop_stage=100
. shared/parse_options.sh || exit 1
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Prepare gigaspeech dataset."
mkdir -p data/fbank
if [ ! -e data/fbank/.gigaspeech.done ]; then
pushd ../ASR
./prepare.sh --stage 0 --stop-stage 9
./prepare.sh --stage 11 --stop-stage 11
popd
pushd data/fbank
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_DEV.jsonl.gz) .
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_DEV.lca) .
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_TEST.jsonl.gz) .
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_TEST.lca) .
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_L.jsonl.gz) .
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_L.lca) .
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_M.jsonl.gz) .
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_M.lca) .
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_S.jsonl.gz) .
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_S.lca) .
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_XS.jsonl.gz) .
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_XS.lca) .
ln -svf $(realpath ../ASR/data/fbank/XL_split) .
ln -svf $(realpath ../ASR/data/fbank/musan_cuts.jsonl.gz) .
ln -svf $(realpath ../ASR/data/fbank/musan_feats) .
popd
pushd data
ln -svf $(realpath ../ASR/data/lang_bpe_500) .
popd
touch data/fbank/.gigaspeech.done
else
log "Gigaspeech dataset already exists, skipping."
fi
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare open commands dataset."
mkdir -p data/fbank
if [ ! -e data/fbank/.fluent_speech_commands.done ]; then
pushd data
git clone https://github.com/pkufool/open-commands.git
ln -svf $(realpath ./open-commands/EN/small/commands.txt) commands_small.txt
ln -svf $(realpath ./open-commands/EN/large/commands.txt) commands_large.txt
pushd open-commands
./script/prepare.sh --stage 2 --stop-stage 2
./script/prepare.sh --stage 6 --stop-stage 6
popd
popd
pushd data/fbank
ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_large.jsonl.gz) .
ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_large) .
ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_small.jsonl.gz) .
ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_small) .
ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_valid.jsonl.gz) .
ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_valid) .
ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_train.jsonl.gz) .
ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_train) .
popd
touch data/fbank/.fluent_speech_commands.done
else
log "Fluent speech commands dataset already exists, skipping."
fi
fi

197
egs/gigaspeech/KWS/run.sh Executable file
View File

@ -0,0 +1,197 @@
#!/usr/bin/env bash
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
export CUDA_VISIBLE_DEVICES="0,1,2,3"
export PYTHONPATH=../../../:$PYTHONPATH
stage=0
stop_stage=100
. shared/parse_options.sh || exit 1
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Train a model."
if [ ! -e data/fbank/.gigaspeech.done ]; then
log "You need to run the prepare.sh first."
exit -1
fi
python ./zipformer/train.py \
--world-size 4 \
--exp-dir zipformer/exp \
--decoder-dim 320 \
--joiner-dim 320 \
--num-encoder-layers 1,1,1,1,1,1 \
--feedforward-dim 192,192,192,192,192,192 \
--encoder-dim 128,128,128,128,128,128 \
--encoder-unmasked-dim 128,128,128,128,128,128 \
--num-epochs 12 \
--lr-epochs 1.5 \
--use-fp16 1 \
--start-epoch 1 \
--subset XL \
--bpe-model data/lang_bpe_500/bpe.model \
--causal 1 \
--max-duration 1000
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Decode the model."
for t in small, large; do
python ./zipformer/decode.py \
--epoch 12 \
--avg 2 \
--exp-dir ./zipformer/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--causal 1 \
--chunk-size 16 \
--left-context-frames 64 \
--decoder-dim 320 \
--joiner-dim 320 \
--num-encoder-layers 1,1,1,1,1,1 \
--feedforward-dim 192,192,192,192,192,192 \
--encoder-dim 128,128,128,128,128,128 \
--encoder-unmasked-dim 128,128,128,128,128,128 \
--test-set $t \
--keywords-score 1.0 \
--keywords-threshold 0.35 \
--keywords-file ./data/commands_${t}.txt \
--max-duration 3000
done
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Export the model."
python ./zipformer/export.py \
--epoch 12 \
--avg 2 \
--exp-dir ./zipformer/exp \
--tokens data/lang_bpe_500/tokens.txt \
--causal 1 \
--chunk-size 16 \
--left-context-frames 64 \
--decoder-dim 320 \
--joiner-dim 320 \
--num-encoder-layers 1,1,1,1,1,1 \
--feedforward-dim 192,192,192,192,192,192 \
--encoder-dim 128,128,128,128,128,128 \
--encoder-unmasked-dim 128,128,128,128,128,128
python ./zipformer/export_onnx_streaming.py \
--exp-dir zipformer/exp \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 12 \
--avg 2 \
--chunk-size 16 \
--left-context-frames 128 \
--decoder-dim 320 \
--joiner-dim 320 \
--num-encoder-layers 1,1,1,1,1,1 \
--feedforward-dim 192,192,192,192,192,192 \
--encoder-dim 128,128,128,128,128,128 \
--encoder-unmasked-dim 128,128,128,128,128,128 \
--causal 1
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 2: Finetune the model"
# The following configuration of lr schedule should work well
# You may also tune the following parameters to adjust learning rate schedule
base_lr=0.0005
lr_epochs=100
lr_batches=100000
# We recommend to start from an averaged model
finetune_ckpt=zipformer/exp/pretrained.pt
./zipformer/finetune.py \
--world-size 4 \
--num-epochs 10 \
--start-epoch 1 \
--exp-dir zipformer/exp_finetune \
--bpe-model data/lang_bpe_500/bpe.model \
--use-fp16 1 \
--decoder-dim 320 \
--joiner-dim 320 \
--num-encoder-layers 1,1,1,1,1,1 \
--feedforward-dim 192,192,192,192,192,192 \
--encoder-dim 128,128,128,128,128,128 \
--encoder-unmasked-dim 128,128,128,128,128,128 \
--causal 1 \
--base-lr $base_lr \
--lr-epochs $lr_epochs \
--lr-batches $lr_batches \
--finetune-ckpt $finetune_ckpt \
--max-duration 1500
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 1: Decode the finetuned model."
for t in small, large; do
python ./zipformer/decode.py \
--epoch 10 \
--avg 2 \
--exp-dir ./zipformer/exp_finetune \
--bpe-model data/lang_bpe_500/bpe.model \
--causal 1 \
--chunk-size 16 \
--left-context-frames 64 \
--decoder-dim 320 \
--joiner-dim 320 \
--num-encoder-layers 1,1,1,1,1,1 \
--feedforward-dim 192,192,192,192,192,192 \
--encoder-dim 128,128,128,128,128,128 \
--encoder-unmasked-dim 128,128,128,128,128,128 \
--test-set $t \
--keywords-score 1.0 \
--keywords-threshold 0.35 \
--keywords-file ./data/commands_${t}.txt \
--max-duration 3000
done
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 2: Export the finetuned model."
python ./zipformer/export.py \
--epoch 10 \
--avg 2 \
--exp-dir ./zipformer/exp_finetune \
--tokens data/lang_bpe_500/tokens.txt \
--causal 1 \
--chunk-size 16 \
--left-context-frames 64 \
--decoder-dim 320 \
--joiner-dim 320 \
--num-encoder-layers 1,1,1,1,1,1 \
--feedforward-dim 192,192,192,192,192,192 \
--encoder-dim 128,128,128,128,128,128 \
--encoder-unmasked-dim 128,128,128,128,128,128
python ./zipformer/export_onnx_streaming.py \
--exp-dir zipformer/exp_finetune \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 10 \
--avg 2 \
--chunk-size 16 \
--left-context-frames 128 \
--decoder-dim 320 \
--joiner-dim 320 \
--num-encoder-layers 1,1,1,1,1,1 \
--feedforward-dim 192,192,192,192,192,192 \
--encoder-dim 128,128,128,128,128,128 \
--encoder-unmasked-dim 128,128,128,128,128,128 \
--causal 1
fi

1
egs/gigaspeech/KWS/shared Symbolic link
View File

@ -0,0 +1 @@
../../../icefall/shared

View File

@ -0,0 +1,477 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2024 Xiaomi Corporation (Author: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import glob
import inspect
import logging
import re
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import lhotse
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import (
CutConcatenate,
CutMix,
DynamicBucketingSampler,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SimpleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import AudioSamples, OnTheFlyFeatures
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class GigaSpeechAsrDataModule:
"""
DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- augmentation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/fbank"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
default=False,
help="When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding.",
)
group.add_argument(
"--duration-factor",
type=float,
default=1.0,
help="Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch.",
)
group.add_argument(
"--gap",
type=float,
default=1.0,
help="The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when "
"noise augmentation is used.",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--enable-spec-aug",
type=str2bool,
default=True,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--spec-aug-time-warp-factor",
type=int,
default=80,
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
group.add_argument(
"--input-strategy",
type=str,
default="PrecomputedFeatures",
help="AudioSamples or PrecomputedFeatures",
)
# GigaSpeech specific arguments
group.add_argument(
"--subset",
type=str,
default="XL",
help="Select the GigaSpeech subset (XS|S|M|L|XL)",
)
group.add_argument(
"--small-dev",
type=str2bool,
default=False,
help="Should we use only 1000 utterances for dev (speeds up training)",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms.append(
CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
if self.args.concatenate_cuts:
logging.info(
f"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}."
)
# Cut concatenation should be the first transform in the list,
# so that if we e.g. mix noise in, it will fill the gaps between
# different utterances.
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
num_frame_masks = 10
num_frame_masks_parameter = inspect.signature(
SpecAugment.__init__
).parameters["num_frame_masks"]
if num_frame_masks_parameter.default == 1:
num_frame_masks = 2
logging.info(f"Num frame mask: {num_frame_masks}")
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=num_frame_masks,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
)
else:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
train = K2SpeechRecognitionDataset(
input_strategy=eval(self.args.input_strategy)(),
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
# NOTE: the PerturbSpeed transform should be added only if we
# remove it from data prep stage.
# Add on-the-fly speed perturbation; since originally it would
# have increased epoch size by 3, we will apply prob 2/3 and use
# 3x more epochs.
# Speed perturbation probably should come first before
# concatenation, but in principle the transforms order doesn't have
# to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
drop_last=self.args.drop_last,
buffer_size=self.args.num_buckets * 2000,
shuffle_buffer_size=self.args.num_buckets * 5000,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
transforms = []
if self.args.concatenate_cuts:
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
return_cuts=self.args.return_cuts,
)
else:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
num_buckets=self.args.num_buckets,
buffer_size=self.args.num_buckets * 2000,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info(f"About to get train {self.args.subset} cuts")
if self.args.subset == "XL":
filenames = glob.glob(
f"{self.args.manifest_dir}/XL_split/gigaspeech_cuts_XL.*.jsonl.gz"
)
pattern = re.compile(r"gigaspeech_cuts_XL.([0-9]+).jsonl.gz")
idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames)
idx_filenames = sorted(idx_filenames, key=lambda x: x[0])
sorted_filenames = [f[1] for f in idx_filenames]
logging.info(
f"Loading GigaSpeech {len(sorted_filenames)} splits in lazy mode"
)
cuts_train = lhotse.combine(
lhotse.load_manifest_lazy(p) for p in sorted_filenames
)
else:
path = (
self.args.manifest_dir / f"gigaspeech_cuts_{self.args.subset}.jsonl.gz"
)
cuts_train = CutSet.from_jsonl_lazy(path)
return cuts_train
@lru_cache()
def dev_cuts(self) -> CutSet:
logging.info("About to get dev cuts")
cuts_valid = load_manifest_lazy(
self.args.manifest_dir / "gigaspeech_cuts_DEV.jsonl.gz"
)
if self.args.small_dev:
return cuts_valid.subset(first=1000)
else:
return cuts_valid
@lru_cache()
def test_cuts(self) -> CutSet:
logging.info("About to get test cuts")
return load_manifest_lazy(
self.args.manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz"
)
@lru_cache()
def fsc_train_cuts(self) -> CutSet:
logging.info("About to get fluent speech commands train cuts")
return load_manifest_lazy(
self.args.manifest_dir / "fluent_speech_commands_cuts_train.jsonl.gz"
)
@lru_cache()
def fsc_valid_cuts(self) -> CutSet:
logging.info("About to get fluent speech commands valid cuts")
return load_manifest_lazy(
self.args.manifest_dir / "fluent_speech_commands_cuts_valid.jsonl.gz"
)
@lru_cache()
def fsc_test_small_cuts(self) -> CutSet:
logging.info("About to get fluent speech commands small test cuts")
return load_manifest_lazy(
self.args.manifest_dir / "fluent_speech_commands_cuts_small.jsonl.gz"
)
@lru_cache()
def fsc_test_large_cuts(self) -> CutSet:
logging.info("About to get fluent speech commands large test cuts")
return load_manifest_lazy(
self.args.manifest_dir / "fluent_speech_commands_cuts_large.jsonl.gz"
)

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,687 @@
#!/usr/bin/env python3
#
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao,
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
./zipformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--keywords-file keywords.txt \
--beam-size 4
"""
import argparse
import logging
import math
import os
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import GigaSpeechAsrDataModule
from beam_search import keywords_search
from lhotse.cut import Cut
from train import add_model_arguments, get_model, get_params
from icefall import ContextGraph
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
make_pad_mask,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
@dataclass
class KwMetric:
TP: int = 0 # True positive
FN: int = 0 # False negative
FP: int = 0 # False positive
TN: int = 0 # True negative
FN_list: List[str] = field(default_factory=list)
FP_list: List[str] = field(default_factory=list)
TP_list: List[str] = field(default_factory=list)
def __str__(self) -> str:
return f"(TP:{self.TP}, FN:{self.FN}, FP:{self.FP}, TN:{self.TN})"
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--beam",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
parser.add_argument(
"--keywords-file",
type=str,
help="File contains keywords.",
)
parser.add_argument(
"--test-set",
type=str,
default="small",
help="small or large",
)
parser.add_argument(
"--keywords-score",
type=float,
default=1.5,
help="""
The default boosting score (token level) for keywords. it will boost the
paths that match keywords to make them survive beam search.
""",
)
parser.add_argument(
"--keywords-threshold",
type=float,
default=0.35,
help="The default threshold (probability) to trigger the keyword.",
)
parser.add_argument(
"--num-tailing-blanks",
type=int,
default=1,
help="The number of tailing blanks should have after hitting one keyword.",
)
parser.add_argument(
"--blank-penalty",
type=float,
default=0.0,
help="""
The penalty applied on blank symbol during decoding.
Note: It is a positive value that would be applied to logits like
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
[batch_size, vocab] and blank id is 0).
""",
)
add_model_arguments(parser)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
keywords_graph: Optional[ContextGraph] = None,
) -> List[List[Tuple[str, Tuple[int, int]]]]:
"""Decode one batch and return the result in a list.
The length of the list equals to batch size, the i-th element contains the
triggered keywords for the i-th utterance in the given batch. The triggered
keywords are also a list, each of it contains a tuple of hitting keyword and
the corresponding start timestamps and end timestamps of the hitting keyword.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
keywords_graph:
The graph containing keywords.
Returns:
Return the decoding result. See above description for the format of
the returned list.
"""
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
if params.causal:
# this seems to cause insertions at the end of the utterance if used with zipformer.
pad_len = 30
feature_lens += pad_len
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, pad_len),
value=LOG_EPS,
)
encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)
ans_dict = keywords_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
keywords_graph=keywords_graph,
beam=params.beam,
num_tailing_blanks=params.num_tailing_blanks,
blank_penalty=params.blank_penalty,
)
hyps = []
for ans in ans_dict:
hyp = []
for hit in ans:
hyp.append((hit.phrase, (hit.timestamps[0], hit.timestamps[-1])))
hyps.append(hyp)
return hyps
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
keywords_graph: ContextGraph,
keywords: Set[str],
test_only_keywords: bool,
) -> Tuple[List[Tuple[str, List[str], List[str]]], KwMetric]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
keywords_graph:
The graph containing keywords.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
log_interval = 50
results = []
metric = {"all": KwMetric()}
for k in keywords:
metric[k] = KwMetric()
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps = decode_one_batch(
params=params,
model=model,
sp=sp,
keywords_graph=keywords_graph,
batch=batch,
)
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_text = ref_text.upper()
ref_words = ref_text.split()
hyp_words = [x[0] for x in hyp_words]
# for computing WER
this_batch.append((cut_id, ref_words, " ".join(hyp_words).split()))
hyp_set = set(hyp_words) # each item is a keyword phrase
if len(hyp_words) > 1:
logging.warning(
f"Cut {cut_id} triggers more than one keywords : {hyp_words},"
f"please check the transcript to see if it really has more "
f"than one keywords, if so consider splitting this audio and"
f"keep only one keyword for each audio."
)
hyp_str = " | ".join(
hyp_words
) # The triggered keywords for this utterance.
TP = False
FP = False
for x in hyp_set:
assert x in keywords, x # can only trigger keywords
if (test_only_keywords and x == ref_text) or (
not test_only_keywords and x in ref_text
):
TP = True
metric[x].TP += 1
metric[x].TP_list.append(f"({ref_text} -> {x})")
if (test_only_keywords and x != ref_text) or (
not test_only_keywords and x not in ref_text
):
FP = True
metric[x].FP += 1
metric[x].FP_list.append(f"({ref_text} -> {x})")
if TP:
metric["all"].TP += 1
if FP:
metric["all"].FP += 1
TN = True # all keywords are true negative then the summery is true negative.
FN = False
for x in keywords:
if x not in ref_text and x not in hyp_set:
metric[x].TN += 1
continue
TN = False
if (test_only_keywords and x == ref_text) or (
not test_only_keywords and x in ref_text
):
fn = True
for y in hyp_set:
if (test_only_keywords and y == ref_text) or (
not test_only_keywords and y in ref_text
):
fn = False
break
if fn:
FN = True
metric[x].FN += 1
metric[x].FN_list.append(f"({ref_text} -> {hyp_str})")
if TN:
metric["all"].TN += 1
if FN:
metric["all"].FN += 1
results.extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results, metric
def save_results(
params: AttributeDict,
test_set_name: str,
results: List[Tuple[str, List[str], List[str]]],
metric: KwMetric,
):
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True)
logging.info("Wrote detailed error stats to {}".format(errs_filename))
metric_filename = params.res_dir / f"metric-{test_set_name}-{params.suffix}.txt"
with open(metric_filename, "w") as of:
width = 10
for key, item in sorted(
metric.items(), key=lambda x: (x[1].FP, x[1].FN), reverse=True
):
acc = (item.TP + item.TN) / (item.TP + item.TN + item.FP + item.FN)
precision = (
0.0 if (item.TP + item.FP) == 0 else item.TP / (item.TP + item.FP)
)
recall = 0.0 if (item.TP + item.FN) == 0 else item.TP / (item.TP + item.FN)
fpr = 0.0 if (item.FP + item.TN) == 0 else item.FP / (item.FP + item.TN)
s = f"{key}:\n"
s += f"\t{'TP':{width}}{'FP':{width}}{'FN':{width}}{'TN':{width}}\n"
s += f"\t{str(item.TP):{width}}{str(item.FP):{width}}{str(item.FN):{width}}{str(item.TN):{width}}\n"
s += f"\tAccuracy: {acc:.3f}\n"
s += f"\tPrecision: {precision:.3f}\n"
s += f"\tRecall(PPR): {recall:.3f}\n"
s += f"\tFPR: {fpr:.3f}\n"
s += f"\tF1: {0.0 if precision * recall == 0 else 2 * precision * recall / (precision + recall):.3f}\n"
if key != "all":
s += f"\tTP list: {' # '.join(item.TP_list)}\n"
s += f"\tFP list: {' # '.join(item.FP_list)}\n"
s += f"\tFN list: {' # '.join(item.FN_list)}\n"
of.write(s + "\n")
if key == "all":
logging.info(s)
of.write(f"\n\n{params.keywords_config}")
logging.info("Wrote metric stats to {}".format(metric_filename))
@torch.no_grad()
def main():
parser = get_parser()
GigaSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
params.res_dir = params.exp_dir / "kws"
params.suffix = params.test_set
if params.iter > 0:
params.suffix += f"-iter-{params.iter}-avg-{params.avg}"
else:
params.suffix += f"-epoch-{params.epoch}-avg-{params.avg}"
if params.causal:
assert (
"," not in params.chunk_size
), "chunk_size should be one value in decoding."
assert (
"," not in params.left_context_frames
), "left_context_frames should be one value in decoding."
params.suffix += f"-chunk-{params.chunk_size}"
params.suffix += f"-left-context-{params.left_context_frames}"
params.suffix += f"-score-{params.keywords_score}"
params.suffix += f"-threshold-{params.keywords_threshold}"
params.suffix += f"-tailing-blanks-{params.num_tailing_blanks}"
if params.blank_penalty != 0:
params.suffix += f"-blank-penalty-{params.blank_penalty}"
params.suffix += f"-keywords-{params.keywords_file.split('/')[-1]}"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
phrases = []
token_ids = []
keywords_scores = []
keywords_thresholds = []
keywords_config = []
with open(params.keywords_file, "r") as f:
for line in f.readlines():
keywords_config.append(line)
score = 0
threshold = 0
keyword = []
words = line.strip().upper().split()
for word in words:
word = word.strip()
if word[0] == ":":
score = float(word[1:])
continue
if word[0] == "#":
threshold = float(word[1:])
continue
keyword.append(word)
keyword = " ".join(keyword)
phrases.append(keyword)
token_ids.append(sp.encode(keyword))
keywords_scores.append(score)
keywords_thresholds.append(threshold)
params.keywords_config = "".join(keywords_config)
keywords_graph = ContextGraph(
context_score=params.keywords_score, ac_threshold=params.keywords_threshold
)
keywords_graph.build(
token_ids=token_ids,
phrases=phrases,
scores=keywords_scores,
ac_thresholds=keywords_thresholds,
)
keywords = set(phrases)
logging.info("About to create model")
model = get_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
gigaspeech = GigaSpeechAsrDataModule(args)
test_cuts = gigaspeech.test_cuts()
test_dl = gigaspeech.test_dataloaders(test_cuts)
if params.test_set == "small":
test_fsc_small_cuts = gigaspeech.fsc_test_small_cuts()
test_fsc_small_dl = gigaspeech.test_dataloaders(test_fsc_small_cuts)
test_sets = ["small-fsc", "test"]
test_dls = [test_fsc_small_dl, test_dl]
else:
assert params.test_set == "large", params.test_set
test_fsc_large_cuts = gigaspeech.fsc_test_large_cuts()
test_fsc_large_dl = gigaspeech.test_dataloaders(test_fsc_large_cuts)
test_sets = ["large-fsc", "test"]
test_dls = [test_fsc_large_dl, test_dl]
for test_set, test_dl in zip(test_sets, test_dls):
results, metric = decode_dataset(
dl=test_dl,
params=params,
model=model,
sp=sp,
keywords_graph=keywords_graph,
keywords=keywords,
test_only_keywords="fsc" in test_set,
)
save_results(
params=params,
test_set_name=test_set,
results=results,
metric=metric,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/decoder.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/transducer_stateless/encoder_interface.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/export-onnx-streaming.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/export.py

View File

@ -0,0 +1,643 @@
#!/usr/bin/env python3
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang,
# Mingshuang Luo,
# Zengwei Yao,
# Yifan Yang,
# Daniel Povey)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
# For non-streaming model training:
./zipformer/finetune.py \
--world-size 8 \
--num-epochs 30 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp \
--max-duration 1000
# For streaming model training:
./zipformer/fintune.py \
--world-size 8 \
--num-epochs 30 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp \
--causal 1 \
--max-duration 1000
It supports training with:
- transducer loss (default), with `--use-transducer True --use-ctc False`
- ctc loss (not recommended), with `--use-transducer False --use-ctc True`
- transducer loss & ctc loss, with `--use-transducer True --use-ctc True`
"""
import argparse
import copy
import logging
import warnings
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union
import k2
import optim
import sentencepiece as spm
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import GigaSpeechAsrDataModule
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut, CutSet
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import AsrModel
from optim import Eden, ScaledAdam
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from train import (
add_model_arguments,
add_training_arguments,
compute_loss,
compute_validation_loss,
display_and_save_batch,
get_adjusted_batch_count,
get_model,
get_params,
load_checkpoint_if_available,
save_checkpoint,
scan_pessimistic_batches_for_oom,
set_batch_count,
)
from icefall import diagnostics
from icefall.checkpoint import remove_checkpoints
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import (
save_checkpoint_with_global_batch_idx,
update_averaged_model,
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
MetricsTracker,
get_parameter_groups_with_lrs,
setup_logger,
str2bool,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
def add_finetune_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--use-mux",
type=str2bool,
default=False,
help="""
Whether to adapt. If true, we will mix 5% of the new data
with 95% of the original data to fine-tune.
""",
)
parser.add_argument(
"--init-modules",
type=str,
default=None,
help="""
Modules to be initialized. It matches all parameters starting with
a specific key. The keys are given with Comma seperated. If None,
all modules will be initialised. For example, if you only want to
initialise all parameters staring with "encoder", use "encoder";
if you want to initialise parameters starting with encoder or decoder,
use "encoder,joiner".
""",
)
parser.add_argument(
"--finetune-ckpt",
type=str,
default=None,
help="Fine-tuning from which checkpoint (a path to a .pt file)",
)
parser.add_argument(
"--continue-finetune",
type=str2bool,
default=False,
help="Continue finetuning or finetune from pre-trained model",
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
add_training_arguments(parser)
add_model_arguments(parser)
add_finetune_arguments(parser)
return parser
def load_model_params(
ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True
):
"""Load model params from checkpoint
Args:
ckpt (str): Path to the checkpoint
model (nn.Module): model to be loaded
"""
logging.info(f"Loading checkpoint from {ckpt}")
checkpoint = torch.load(ckpt, map_location="cpu")
# if module list is empty, load the whole model from ckpt
if not init_modules:
if next(iter(checkpoint["model"])).startswith("module."):
logging.info("Loading checkpoint saved by DDP")
dst_state_dict = model.state_dict()
src_state_dict = checkpoint["model"]
for key in dst_state_dict.keys():
src_key = "{}.{}".format("module", key)
dst_state_dict[key] = src_state_dict.pop(src_key)
assert len(src_state_dict) == 0
model.load_state_dict(dst_state_dict, strict=strict)
else:
model.load_state_dict(checkpoint["model"], strict=strict)
else:
src_state_dict = checkpoint["model"]
dst_state_dict = model.state_dict()
for module in init_modules:
logging.info(f"Loading parameters starting with prefix {module}")
src_keys = [
k for k in src_state_dict.keys() if k.startswith(module.strip() + ".")
]
dst_keys = [
k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".")
]
assert set(src_keys) == set(dst_keys) # two sets should match exactly
for key in src_keys:
dst_state_dict[key] = src_state_dict.pop(key)
model.load_state_dict(dst_state_dict, strict=strict)
return None
def train_one_epoch(
params: AttributeDict,
model: Union[nn.Module, DDP],
optimizer: torch.optim.Optimizer,
scheduler: LRSchedulerType,
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
rank: int = 0,
) -> None:
"""Train the model for one epoch.
The training loss from the mean of all frames is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model for training.
optimizer:
The optimizer we are using.
scheduler:
The learning rate scheduler, we call step() every step.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
scaler:
The scaler used for mix precision training.
model_avg:
The stored model averaged from the start of training.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
Number of nodes in DDP training. If it is 1, DDP is disabled.
rank:
The rank of the node in DDP training. If no DDP is used, it should
be set to 0.
"""
model.train()
tot_loss = MetricsTracker()
saved_bad_model = False
def save_bad_model(suffix: str = ""):
save_checkpoint_impl(
filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
model=model,
model_avg=model_avg,
params=params,
optimizer=optimizer,
scheduler=scheduler,
sampler=train_dl.sampler,
scaler=scaler,
rank=0,
)
for batch_idx, batch in enumerate(train_dl):
if batch_idx % 10 == 0:
set_batch_count(model, get_adjusted_batch_count(params) + 100000)
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
is_training=True,
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
scaler.scale(loss).backward()
# if params.continue_finetune:
# set_batch_count(model, params.batch_idx_train)
# else:
# set_batch_count(model, params.batch_idx_train + 100000)
scheduler.step_batch(params.batch_idx_train)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
except: # noqa
save_bad_model()
display_and_save_batch(batch, params=params, sp=sp)
raise
if params.print_diagnostics and batch_idx == 5:
return
if (
rank == 0
and params.batch_idx_train > 0
and params.batch_idx_train % params.average_period == 0
):
update_averaged_model(
params=params,
model_cur=model,
model_avg=model_avg,
)
if (
params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0
):
save_checkpoint_with_global_batch_idx(
out_dir=params.exp_dir,
global_batch_idx=params.batch_idx_train,
model=model,
model_avg=model_avg,
params=params,
optimizer=optimizer,
scheduler=scheduler,
sampler=train_dl.sampler,
scaler=scaler,
rank=rank,
)
remove_checkpoints(
out_dir=params.exp_dir,
topk=params.keep_last_k,
rank=rank,
)
if batch_idx % 100 == 0 and params.use_fp16:
# If the grad scale was less than 1, try increasing it. The _growth_interval
# of the grad scaler is configurable, but we can't configure it to have different
# behavior depending on the current grad scale.
cur_grad_scale = scaler._scale.item()
if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
scaler.update(cur_grad_scale * 2.0)
if cur_grad_scale < 0.01:
if not saved_bad_model:
save_bad_model(suffix="-first-warning")
saved_bad_model = True
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
logging.info(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
f"lr: {cur_lr:.2e}, "
+ (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
)
if tb_writer is not None:
tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
)
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if params.use_fp16:
tb_writer.add_scalar(
"train/grad_scale", cur_grad_scale, params.batch_idx_train
)
if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
model=model,
sp=sp,
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
)
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
fix_random_seed(params.seed)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
if not params.use_transducer:
params.ctc_loss_scale = 1.0
logging.info(params)
logging.info("About to create model")
model = get_model(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
assert params.save_every_n >= params.average_period
model_avg: Optional[nn.Module] = None
if params.continue_finetune:
assert params.start_epoch > 0, params.start_epoch
checkpoints = load_checkpoint_if_available(
params=params, model=model, model_avg=model_avg
)
else:
modules = params.init_modules.split(",") if params.init_modules else None
checkpoints = load_model_params(
ckpt=params.finetune_ckpt, model=model, init_modules=modules
)
if rank == 0:
# model_avg is only used with rank 0
model_avg = copy.deepcopy(model).to(torch.float64)
model.to(device)
if world_size > 1:
logging.info("Using DDP")
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
optimizer = ScaledAdam(
get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
lr=params.base_lr, # should have no effect
clipping_scale=2.0,
)
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, warmup_start=1.0)
if checkpoints and "optimizer" in checkpoints:
logging.info("Loading optimizer state dict")
optimizer.load_state_dict(checkpoints["optimizer"])
if (
checkpoints
and "scheduler" in checkpoints
and checkpoints["scheduler"] is not None
):
logging.info("Loading scheduler state dict")
scheduler.load_state_dict(checkpoints["scheduler"])
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
if params.inf_check:
register_inf_check_hooks(model)
def remove_short_utt(c: Cut):
# In ./zipformer.py, the conv module uses the following expression
# for subsampling
T = ((c.num_frames - 7) // 2 + 1) // 2
return T > 0
gigaspeech = GigaSpeechAsrDataModule(args)
if params.use_mux:
train_cuts = CutSet.mux(
gigaspeech.train_cuts(),
gigaspeech.fsc_train_cuts(),
weights=[0.9, 0.1],
)
else:
train_cuts = gigaspeech.fsc_train_cuts()
train_cuts = train_cuts.filter(remove_short_utt)
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
# We only load the sampler's state dict when it loads a checkpoint
# saved in the middle of an epoch
sampler_state_dict = checkpoints["sampler"]
else:
sampler_state_dict = None
train_dl = gigaspeech.train_dataloaders(
train_cuts, sampler_state_dict=sampler_state_dict
)
valid_cuts = gigaspeech.fsc_valid_cuts()
valid_cuts = valid_cuts.filter(remove_short_utt)
valid_dl = gigaspeech.valid_dataloaders(valid_cuts)
if not params.print_diagnostics and params.scan_for_oom_batches:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
optimizer=optimizer,
sp=sp,
params=params,
)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
for epoch in range(params.start_epoch, params.num_epochs + 1):
scheduler.step_epoch(epoch - 1)
fix_random_seed(params.seed + epoch - 1)
train_dl.sampler.set_epoch(epoch - 1)
if tb_writer is not None:
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
params.cur_epoch = epoch
train_one_epoch(
params=params,
model=model,
model_avg=model_avg,
optimizer=optimizer,
scheduler=scheduler,
sp=sp,
train_dl=train_dl,
valid_dl=valid_dl,
scaler=scaler,
tb_writer=tb_writer,
world_size=world_size,
rank=rank,
)
if params.print_diagnostics:
diagnostic.print_diagnostics()
break
save_checkpoint(
params=params,
model=model,
model_avg=model_avg,
optimizer=optimizer,
scheduler=scheduler,
sampler=train_dl.sampler,
scaler=scaler,
rank=rank,
)
logging.info("Done!")
if world_size > 1:
torch.distributed.barrier()
cleanup_dist()
def main():
parser = get_parser()
GigaSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)
if __name__ == "__main__":
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
main()

View File

@ -0,0 +1 @@
../../ASR/zipformer/gigaspeech_scoring.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/joiner.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/model.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/optim.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/scaling.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/subsampling.py

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/zipformer.py

View File

@ -35,6 +35,7 @@ The following table lists the differences among them.
| `lstm_transducer_stateless2` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model + gigaspeech (multi-dataset setup) |
| `lstm_transducer_stateless3` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model + gradient filter + delay penalty |
| `zipformer` | Upgraded Zipformer | Embedding + Conv1d | The latest recipe |
| `zipformer_adapter` | Upgraded Zipformer | Embedding + Conv1d | It supports domain adaptation of Zipformer using parameter efficient adapters |
The decoder in `transducer_stateless` is modified from the paper
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).

View File

@ -1526,7 +1526,7 @@ done
You may also decode using LODR + LM shallow fusion. This decoding method is proposed in <https://arxiv.org/pdf/2203.16776.pdf>.
It subtracts the internal language model score during shallow fusion, which is approximated by a bi-gram model. The bi-gram can be
generated by `generate-lm.sh`, or you may download it from <https://huggingface.co/marcoyang/librispeech_bigram>.
generated by `prepare_lm.sh` at stage 4, or you may download it from <https://huggingface.co/marcoyang/librispeech_bigram>.
The decoding command is as follows:

View File

@ -24,8 +24,7 @@ To run this file, do:
"""
import torch
from train import get_params, get_ctc_model
from train import get_ctc_model, get_params
def test_model():

View File

@ -59,9 +59,9 @@ import onnx
import torch
import torch.nn as nn
from decoder import Decoder
from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
from emformer import Emformer
from scaling_converter import convert_scaled_to_non_scaled
from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,

View File

@ -39,7 +39,7 @@ Usage of this script:
import argparse
import logging
import math
from typing import List
from typing import List, Optional
import kaldifeat
import sentencepiece as spm
@ -47,7 +47,6 @@ import torch
import torchaudio
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
from torch.nn.utils.rnn import pad_sequence
from typing import Optional, List
def get_parser():

View File

@ -1,20 +0,0 @@
#!/usr/bin/env bash
lang_dir=data/lang_bpe_500
for ngram in 2 3 4 5; do
if [ ! -f $lang_dir/${ngram}gram.arpa ]; then
./shared/make_kn_lm.py \
-ngram-order ${ngram} \
-text $lang_dir/transcript_tokens.txt \
-lm $lang_dir/${ngram}gram.arpa
fi
if [ ! -f $lang_dir/${ngram}gram.fst.txt ]; then
python3 -m kaldilm \
--read-symbol-table="$lang_dir/tokens.txt" \
--disambig-symbol='#0' \
--max-order=${ngram} \
$lang_dir/${ngram}gram.arpa > $lang_dir/${ngram}gram.fst.txt
fi
done

View File

@ -28,6 +28,7 @@
import argparse
import shutil
from pathlib import Path
from typing import Dict
import sentencepiece as spm
@ -57,6 +58,18 @@ def get_args():
return parser.parse_args()
def generate_tokens(lang_dir: Path):
"""
Generate the tokens.txt from a bpe model.
"""
sp = spm.SentencePieceProcessor()
sp.load(str(lang_dir / "bpe.model"))
token2id: Dict[str, int] = {sp.id_to_piece(i): i for i in range(sp.vocab_size())}
with open(lang_dir / "tokens.txt", "w", encoding="utf-8") as f:
for sym, i in token2id.items():
f.write(f"{sym} {i}\n")
def main():
args = get_args()
vocab_size = args.vocab_size
@ -95,6 +108,8 @@ def main():
shutil.copyfile(model_file, f"{lang_dir}/bpe.model")
generate_tokens(lang_dir)
if __name__ == "__main__":
main()

View File

@ -31,28 +31,28 @@ https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stat
"""
import argparse
import torch.multiprocessing as mp
import torch
import torch.nn as nn
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import List, Optional, Tuple
from pathlib import Path
from typing import List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import AsrDataModule
from beam_search import (
fast_beam_search_one_best,
greedy_search_batch,
modified_beam_search,
)
from icefall.utils import AttributeDict, convert_timestamp, setup_logger
from lhotse import CutSet, load_manifest_lazy
from lhotse.cut import Cut
from lhotse.supervision import AlignmentItem
from lhotse.serialization import SequentialJsonlWriter
from lhotse.supervision import AlignmentItem
from icefall.utils import AttributeDict, convert_timestamp, setup_logger
def get_parser():

View File

@ -73,12 +73,11 @@ It will generate the following 3 files inside $repo/exp:
import argparse
import logging
import torch
from onnx_pretrained import OnnxModel
from icefall import is_module_available
import torch
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -6,8 +6,21 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
nj=15
stage=-1
stop_stage=100
# run step 0 to step 5 by default
stage=0
stop_stage=5
# Note: This script just prepare the minimal requirements that needed by a
# transducer training with bpe units.
#
# If you want to use ngram or nnlm, please continue running prepare_lm.sh after
# you succeed running this script.
#
# This script also contains the steps to generate phone based units, but they
# will not run automatically, you can generate the phone based units by
# bash prepare.sh --stage -1 --stop-stage -1
# bash prepare.sh --stage 6 --stop-stage 6
# We assume dl_dir (download dir) contains the following
# directories and files. If not, they will be downloaded
@ -17,6 +30,18 @@ stop_stage=100
# You can find BOOKS.TXT, test-clean, train-clean-360, etc, inside it.
# You can download them from https://www.openslr.org/12
#
# - $dl_dir/musan
# This directory contains the following directories downloaded from
# http://www.openslr.org/17/
#
# - music
# - noise
# - speech
#
# lm directory is not necessary for transducer training with bpe units, but it
# is needed by phone based modeling, you can download it by running
# bash prepare.sh --stage -1 --stop-stage -1
# then you can see the following files in the directory.
# - $dl_dir/lm
# This directory contains the following files downloaded from
# http://www.openslr.org/resources/11
@ -28,14 +53,7 @@ stop_stage=100
# - librispeech-vocab.txt
# - librispeech-lexicon.txt
# - librispeech-lm-norm.txt.gz
#
# - $dl_dir/musan
# This directory contains the following directories downloaded from
# http://www.openslr.org/17/
#
# - music
# - noise
# - speech
dl_dir=$PWD/download
. shared/parse_options.sh || exit 1
@ -60,6 +78,8 @@ log() {
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "Running prepare.sh"
log "dl_dir: $dl_dir"
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
@ -159,13 +179,49 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Prepare phone based lang"
log "Stage 5: Prepare BPE based lang"
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size}
mkdir -p $lang_dir
if [ ! -f $lang_dir/transcript_words.txt ]; then
log "Generate data for BPE training"
files=$(
find "$dl_dir/LibriSpeech/train-clean-100" -name "*.trans.txt"
find "$dl_dir/LibriSpeech/train-clean-360" -name "*.trans.txt"
find "$dl_dir/LibriSpeech/train-other-500" -name "*.trans.txt"
)
for f in ${files[@]}; do
cat $f | cut -d " " -f 2-
done > $lang_dir/transcript_words.txt
fi
if [ ! -f $lang_dir/bpe.model ]; then
./local/train_bpe_model.py \
--lang-dir $lang_dir \
--vocab-size $vocab_size \
--transcript $lang_dir/transcript_words.txt
fi
done
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Prepare phone based lang"
lang_dir=data/lang_phone
mkdir -p $lang_dir
(echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
cat - $dl_dir/lm/librispeech-lexicon.txt |
sort | uniq > $lang_dir/lexicon.txt
if [ ! -f $dl_dir/lm/librispeech-lexicon.txt ]; then
log "No lexicon file in $dl_dir/lm, please run :"
log "prepare.sh --stage -1 --stop-stage -1"
exit -1
fi
if [ ! -f $lang_dir/lexicon.txt ]; then
(echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
cat - $dl_dir/lm/librispeech-lexicon.txt |
sort | uniq > $lang_dir/lexicon.txt
fi
if [ ! -f $lang_dir/L_disambig.pt ]; then
./local/prepare_lang.py --lang-dir $lang_dir
@ -187,253 +243,3 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
$lang_dir/L_disambig.fst
fi
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Prepare BPE based lang"
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size}
mkdir -p $lang_dir
# We reuse words.txt from phone based lexicon
# so that the two can share G.pt later.
cp data/lang_phone/words.txt $lang_dir
if [ ! -f $lang_dir/transcript_words.txt ]; then
log "Generate data for BPE training"
files=$(
find "$dl_dir/LibriSpeech/train-clean-100" -name "*.trans.txt"
find "$dl_dir/LibriSpeech/train-clean-360" -name "*.trans.txt"
find "$dl_dir/LibriSpeech/train-other-500" -name "*.trans.txt"
)
for f in ${files[@]}; do
cat $f | cut -d " " -f 2-
done > $lang_dir/transcript_words.txt
fi
if [ ! -f $lang_dir/bpe.model ]; then
./local/train_bpe_model.py \
--lang-dir $lang_dir \
--vocab-size $vocab_size \
--transcript $lang_dir/transcript_words.txt
fi
if [ ! -f $lang_dir/L_disambig.pt ]; then
./local/prepare_lang_bpe.py --lang-dir $lang_dir
log "Validating $lang_dir/lexicon.txt"
./local/validate_bpe_lexicon.py \
--lexicon $lang_dir/lexicon.txt \
--bpe-model $lang_dir/bpe.model
fi
if [ ! -f $lang_dir/L.fst ]; then
log "Converting L.pt to L.fst"
./shared/convert-k2-to-openfst.py \
--olabels aux_labels \
$lang_dir/L.pt \
$lang_dir/L.fst
fi
if [ ! -f $lang_dir/L_disambig.fst ]; then
log "Converting L_disambig.pt to L_disambig.fst"
./shared/convert-k2-to-openfst.py \
--olabels aux_labels \
$lang_dir/L_disambig.pt \
$lang_dir/L_disambig.fst
fi
done
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
log "Stage 7: Prepare bigram token-level P for MMI training"
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size}
if [ ! -f $lang_dir/transcript_tokens.txt ]; then
./local/convert_transcript_words_to_tokens.py \
--lexicon $lang_dir/lexicon.txt \
--transcript $lang_dir/transcript_words.txt \
--oov "<UNK>" \
> $lang_dir/transcript_tokens.txt
fi
if [ ! -f $lang_dir/P.arpa ]; then
./shared/make_kn_lm.py \
-ngram-order 2 \
-text $lang_dir/transcript_tokens.txt \
-lm $lang_dir/P.arpa
fi
if [ ! -f $lang_dir/P.fst.txt ]; then
python3 -m kaldilm \
--read-symbol-table="$lang_dir/tokens.txt" \
--disambig-symbol='#0' \
--max-order=2 \
$lang_dir/P.arpa > $lang_dir/P.fst.txt
fi
done
fi
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
log "Stage 8: Prepare G"
# We assume you have installed kaldilm, if not, please install
# it using: pip install kaldilm
mkdir -p data/lm
if [ ! -f data/lm/G_3_gram.fst.txt ]; then
# It is used in building HLG
python3 -m kaldilm \
--read-symbol-table="data/lang_phone/words.txt" \
--disambig-symbol='#0' \
--max-order=3 \
$dl_dir/lm/3-gram.pruned.1e-7.arpa > data/lm/G_3_gram.fst.txt
fi
if [ ! -f data/lm/G_4_gram.fst.txt ]; then
# It is used for LM rescoring
python3 -m kaldilm \
--read-symbol-table="data/lang_phone/words.txt" \
--disambig-symbol='#0' \
--max-order=4 \
$dl_dir/lm/4-gram.arpa > data/lm/G_4_gram.fst.txt
fi
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size}
if [ ! -f $lang_dir/HL.fst ]; then
./local/prepare_lang_fst.py \
--lang-dir $lang_dir \
--ngram-G ./data/lm/G_3_gram.fst.txt
fi
done
fi
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
log "Stage 9: Compile HLG"
./local/compile_hlg.py --lang-dir data/lang_phone
# Note If ./local/compile_hlg.py throws OOM,
# please switch to the following command
#
# ./local/compile_hlg_using_openfst.py --lang-dir data/lang_phone
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size}
./local/compile_hlg.py --lang-dir $lang_dir
# Note If ./local/compile_hlg.py throws OOM,
# please switch to the following command
#
# ./local/compile_hlg_using_openfst.py --lang-dir $lang_dir
done
fi
# Compile LG for RNN-T fast_beam_search decoding
if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
log "Stage 10: Compile LG"
./local/compile_lg.py --lang-dir data/lang_phone
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size}
./local/compile_lg.py --lang-dir $lang_dir
done
fi
if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
log "Stage 11: Generate LM training data"
for vocab_size in ${vocab_sizes[@]}; do
log "Processing vocab_size == ${vocab_size}"
lang_dir=data/lang_bpe_${vocab_size}
out_dir=data/lm_training_bpe_${vocab_size}
mkdir -p $out_dir
./local/prepare_lm_training_data.py \
--bpe-model $lang_dir/bpe.model \
--lm-data $dl_dir/lm/librispeech-lm-norm.txt \
--lm-archive $out_dir/lm_data.pt
done
fi
if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
log "Stage 12: Generate LM validation data"
for vocab_size in ${vocab_sizes[@]}; do
log "Processing vocab_size == ${vocab_size}"
out_dir=data/lm_training_bpe_${vocab_size}
mkdir -p $out_dir
if [ ! -f $out_dir/valid.txt ]; then
files=$(
find "$dl_dir/LibriSpeech/dev-clean" -name "*.trans.txt"
find "$dl_dir/LibriSpeech/dev-other" -name "*.trans.txt"
)
for f in ${files[@]}; do
cat $f | cut -d " " -f 2-
done > $out_dir/valid.txt
fi
lang_dir=data/lang_bpe_${vocab_size}
./local/prepare_lm_training_data.py \
--bpe-model $lang_dir/bpe.model \
--lm-data $out_dir/valid.txt \
--lm-archive $out_dir/lm_data-valid.pt
done
fi
if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then
log "Stage 13: Generate LM test data"
for vocab_size in ${vocab_sizes[@]}; do
log "Processing vocab_size == ${vocab_size}"
out_dir=data/lm_training_bpe_${vocab_size}
mkdir -p $out_dir
if [ ! -f $out_dir/test.txt ]; then
files=$(
find "$dl_dir/LibriSpeech/test-clean" -name "*.trans.txt"
find "$dl_dir/LibriSpeech/test-other" -name "*.trans.txt"
)
for f in ${files[@]}; do
cat $f | cut -d " " -f 2-
done > $out_dir/test.txt
fi
lang_dir=data/lang_bpe_${vocab_size}
./local/prepare_lm_training_data.py \
--bpe-model $lang_dir/bpe.model \
--lm-data $out_dir/test.txt \
--lm-archive $out_dir/lm_data-test.pt
done
fi
if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
log "Stage 14: Sort LM training data"
# Sort LM training data by sentence length in descending order
# for ease of training.
#
# Sentence length equals to the number of BPE tokens
# in a sentence.
for vocab_size in ${vocab_sizes[@]}; do
out_dir=data/lm_training_bpe_${vocab_size}
mkdir -p $out_dir
./local/sort_lm_training_data.py \
--in-lm-data $out_dir/lm_data.pt \
--out-lm-data $out_dir/sorted_lm_data.pt \
--out-statistics $out_dir/statistics.txt
./local/sort_lm_training_data.py \
--in-lm-data $out_dir/lm_data-valid.pt \
--out-lm-data $out_dir/sorted_lm_data-valid.pt \
--out-statistics $out_dir/statistics-valid.txt
./local/sort_lm_training_data.py \
--in-lm-data $out_dir/lm_data-test.pt \
--out-lm-data $out_dir/sorted_lm_data-test.pt \
--out-statistics $out_dir/statistics-test.txt
done
fi

262
egs/librispeech/ASR/prepare_lm.sh Executable file
View File

@ -0,0 +1,262 @@
#!/usr/bin/env bash
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
# This script generate Ngram LM / NNLM and related files that needed by decoding.
# We assume dl_dir (download dir) contains the following
# directories and files. If not, they will be downloaded
# by this script automatically.
#
# - $dl_dir/lm
# This directory contains the following files downloaded from
# http://www.openslr.org/resources/11
#
# - 3-gram.pruned.1e-7.arpa.gz
# - 3-gram.pruned.1e-7.arpa
# - 4-gram.arpa.gz
# - 4-gram.arpa
# - librispeech-vocab.txt
# - librispeech-lexicon.txt
# - librispeech-lm-norm.txt.gz
#
. prepare.sh --stage -1 --stop-stage 6 || exit 1
log "Running prepare_lm.sh"
stage=0
stop_stage=100
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Prepare BPE based lexicon."
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size}
# We reuse words.txt from phone based lexicon
# so that the two can share G.pt later.
cp data/lang_phone/words.txt $lang_dir
if [ ! -f $lang_dir/L_disambig.pt ]; then
./local/prepare_lang_bpe.py --lang-dir $lang_dir
log "Validating $lang_dir/lexicon.txt"
./local/validate_bpe_lexicon.py \
--lexicon $lang_dir/lexicon.txt \
--bpe-model $lang_dir/bpe.model
fi
if [ ! -f $lang_dir/L.fst ]; then
log "Converting L.pt to L.fst"
./shared/convert-k2-to-openfst.py \
--olabels aux_labels \
$lang_dir/L.pt \
$lang_dir/L.fst
fi
if [ ! -f $lang_dir/L_disambig.fst ]; then
log "Converting L_disambig.pt to L_disambig.fst"
./shared/convert-k2-to-openfst.py \
--olabels aux_labels \
$lang_dir/L_disambig.pt \
$lang_dir/L_disambig.fst
fi
done
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare word level G"
# We assume you have installed kaldilm, if not, please install
# it using: pip install kaldilm
mkdir -p data/lm
if [ ! -f data/lm/G_3_gram.fst.txt ]; then
# It is used in building HLG
python3 -m kaldilm \
--read-symbol-table="data/lang_phone/words.txt" \
--disambig-symbol='#0' \
--max-order=3 \
$dl_dir/lm/3-gram.pruned.1e-7.arpa > data/lm/G_3_gram.fst.txt
fi
if [ ! -f data/lm/G_4_gram.fst.txt ]; then
# It is used for LM rescoring
python3 -m kaldilm \
--read-symbol-table="data/lang_phone/words.txt" \
--disambig-symbol='#0' \
--max-order=4 \
$dl_dir/lm/4-gram.arpa > data/lm/G_4_gram.fst.txt
fi
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size}
if [ ! -f $lang_dir/HL.fst ]; then
./local/prepare_lang_fst.py \
--lang-dir $lang_dir \
--ngram-G ./data/lm/G_3_gram.fst.txt
fi
done
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Compile HLG"
./local/compile_hlg.py --lang-dir data/lang_phone
# Note If ./local/compile_hlg.py throws OOM,
# please switch to the following command
#
# ./local/compile_hlg_using_openfst.py --lang-dir data/lang_phone
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size}
./local/compile_hlg.py --lang-dir $lang_dir
# Note If ./local/compile_hlg.py throws OOM,
# please switch to the following command
#
# ./local/compile_hlg_using_openfst.py --lang-dir $lang_dir
done
fi
# Compile LG for RNN-T fast_beam_search decoding
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Compile LG"
./local/compile_lg.py --lang-dir data/lang_phone
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size}
./local/compile_lg.py --lang-dir $lang_dir
done
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Prepare token level ngram G"
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size}
if [ ! -f $lang_dir/transcript_tokens.txt ]; then
./local/convert_transcript_words_to_tokens.py \
--lexicon $lang_dir/lexicon.txt \
--transcript $lang_dir/transcript_words.txt \
--oov "<UNK>" \
> $lang_dir/transcript_tokens.txt
fi
for ngram in 2 3 4 5; do
if [ ! -f $lang_dir/${ngram}gram.arpa ]; then
./shared/make_kn_lm.py \
-ngram-order ${ngram} \
-text $lang_dir/transcript_tokens.txt \
-lm $lang_dir/${ngram}gram.arpa
fi
if [ ! -f $lang_dir/${ngram}gram.fst.txt ]; then
python3 -m kaldilm \
--read-symbol-table="$lang_dir/tokens.txt" \
--disambig-symbol='#0' \
--max-order=${ngram} \
$lang_dir/${ngram}gram.arpa > $lang_dir/${ngram}gram.fst.txt
fi
done
done
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Generate NNLM training data"
for vocab_size in ${vocab_sizes[@]}; do
log "Processing vocab_size == ${vocab_size}"
lang_dir=data/lang_bpe_${vocab_size}
out_dir=data/lm_training_bpe_${vocab_size}
mkdir -p $out_dir
./local/prepare_lm_training_data.py \
--bpe-model $lang_dir/bpe.model \
--lm-data $dl_dir/lm/librispeech-lm-norm.txt \
--lm-archive $out_dir/lm_data.pt
done
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Generate NNLM validation data"
for vocab_size in ${vocab_sizes[@]}; do
log "Processing vocab_size == ${vocab_size}"
out_dir=data/lm_training_bpe_${vocab_size}
mkdir -p $out_dir
if [ ! -f $out_dir/valid.txt ]; then
files=$(
find "$dl_dir/LibriSpeech/dev-clean" -name "*.trans.txt"
find "$dl_dir/LibriSpeech/dev-other" -name "*.trans.txt"
)
for f in ${files[@]}; do
cat $f | cut -d " " -f 2-
done > $out_dir/valid.txt
fi
lang_dir=data/lang_bpe_${vocab_size}
./local/prepare_lm_training_data.py \
--bpe-model $lang_dir/bpe.model \
--lm-data $out_dir/valid.txt \
--lm-archive $out_dir/lm_data-valid.pt
done
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
log "Stage 7: Generate NNLM test data"
for vocab_size in ${vocab_sizes[@]}; do
log "Processing vocab_size == ${vocab_size}"
out_dir=data/lm_training_bpe_${vocab_size}
mkdir -p $out_dir
if [ ! -f $out_dir/test.txt ]; then
files=$(
find "$dl_dir/LibriSpeech/test-clean" -name "*.trans.txt"
find "$dl_dir/LibriSpeech/test-other" -name "*.trans.txt"
)
for f in ${files[@]}; do
cat $f | cut -d " " -f 2-
done > $out_dir/test.txt
fi
lang_dir=data/lang_bpe_${vocab_size}
./local/prepare_lm_training_data.py \
--bpe-model $lang_dir/bpe.model \
--lm-data $out_dir/test.txt \
--lm-archive $out_dir/lm_data-test.pt
done
fi
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
log "Stage 8: Sort NNLM training data"
# Sort LM training data by sentence length in descending order
# for ease of training.
#
# Sentence length equals to the number of BPE tokens
# in a sentence.
for vocab_size in ${vocab_sizes[@]}; do
out_dir=data/lm_training_bpe_${vocab_size}
mkdir -p $out_dir
./local/sort_lm_training_data.py \
--in-lm-data $out_dir/lm_data.pt \
--out-lm-data $out_dir/sorted_lm_data.pt \
--out-statistics $out_dir/statistics.txt
./local/sort_lm_training_data.py \
--in-lm-data $out_dir/lm_data-valid.pt \
--out-lm-data $out_dir/sorted_lm_data-valid.pt \
--out-statistics $out_dir/statistics-valid.txt
./local/sort_lm_training_data.py \
--in-lm-data $out_dir/lm_data-test.pt \
--out-lm-data $out_dir/sorted_lm_data-test.pt \
--out-statistics $out_dir/statistics-test.txt
done
fi

View File

@ -0,0 +1,45 @@
#!/usr/bin/env bash
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
. prepare.sh --stage -1 --stop-stage 6 || exit 1
log "Running prepare_mmi.sh"
stage=0
stop_stage=100
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Prepare bigram token-level P for MMI training"
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size}
if [ ! -f $lang_dir/transcript_tokens.txt ]; then
./local/convert_transcript_words_to_tokens.py \
--lexicon $lang_dir/lexicon.txt \
--transcript $lang_dir/transcript_words.txt \
--oov "<UNK>" \
> $lang_dir/transcript_tokens.txt
fi
if [ ! -f $lang_dir/P.arpa ]; then
./shared/make_kn_lm.py \
-ngram-order 2 \
-text $lang_dir/transcript_tokens.txt \
-lm $lang_dir/P.arpa
fi
if [ ! -f $lang_dir/P.fst.txt ]; then
python3 -m kaldilm \
--read-symbol-table="$lang_dir/tokens.txt" \
--disambig-symbol='#0' \
--max-order=2 \
$lang_dir/P.arpa > $lang_dir/P.fst.txt
fi
done
fi

View File

@ -22,11 +22,12 @@ Usage: ./pruned_transducer_stateless/my_profile.py
import argparse
import logging
import sentencepiece as spm
import torch
from train import add_model_arguments, get_encoder_model, get_params
from icefall.profiler import get_model_profile
from train import get_encoder_model, add_model_arguments, get_params
def get_parser():

View File

@ -75,8 +75,7 @@ import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from onnx_pretrained import greedy_search, OnnxModel
from onnx_pretrained import OnnxModel, greedy_search
from icefall.utils import setup_logger, store_transcripts, write_error_stats

View File

@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import warnings
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Union
@ -31,6 +32,7 @@ from icefall.rnn_lm.model import RnnLmModel
from icefall.transformer_lm.model import TransformerLM
from icefall.utils import (
DecodingResults,
KeywordResult,
add_eos,
add_sos,
get_texts,
@ -789,6 +791,8 @@ class Hypothesis:
# It contains only one entry.
log_prob: torch.Tensor
ac_probs: Optional[List[float]] = None
# timestamp[i] is the frame index after subsampling
# on which ys[i] is decoded
timestamp: List[int] = field(default_factory=list)
@ -805,6 +809,8 @@ class Hypothesis:
# Context graph state
context_state: Optional[ContextState] = None
num_tailing_blanks: int = 0
@property
def key(self) -> str:
"""Return a string representation of self.ys"""
@ -953,6 +959,241 @@ def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
return ans
def keywords_search(
model: nn.Module,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
keywords_graph: ContextGraph,
beam: int = 4,
num_tailing_blanks: int = 0,
blank_penalty: float = 0,
) -> List[List[KeywordResult]]:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
Args:
model:
The transducer model.
encoder_out:
Output from the encoder. Its shape is (N, T, C).
encoder_out_lens:
A 1-D tensor of shape (N,), containing number of valid frames in
encoder_out before padding.
keywords_graph:
A instance of ContextGraph containing keywords and their configurations.
beam:
Number of active paths during the beam search.
num_tailing_blanks:
The number of tailing blanks a keyword should be followed, this is for the
scenario that a keyword will be the prefix of another. In most cases, you
can just set it to 0.
blank_penalty:
The score used to penalize blank probability.
Returns:
Return a list of list of KeywordResult.
"""
assert encoder_out.ndim == 3, encoder_out.shape
assert encoder_out.size(0) >= 1, encoder_out.size(0)
assert keywords_graph is not None
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
blank_id = model.decoder.blank_id
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
device = next(model.parameters()).device
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
B = [HypothesisList() for _ in range(N)]
for i in range(N):
B[i].add(
Hypothesis(
ys=[-1] * (context_size - 1) + [blank_id],
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
context_state=keywords_graph.root,
timestamp=[],
ac_probs=[],
)
)
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
offset = 0
finalized_B = []
sorted_ans = [[] for _ in range(N)]
for t, batch_size in enumerate(batch_size_list):
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end]
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
offset = end
finalized_B = B[batch_size:] + finalized_B
B = B[:batch_size]
hyps_shape = get_hyps_shape(B).to(device)
A = [list(b) for b in B]
B = [HypothesisList() for _ in range(batch_size)]
ys_log_probs = torch.cat(
[hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
) # (num_hyps, 1)
decoder_input = torch.tensor(
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
device=device,
dtype=torch.int64,
) # (num_hyps, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
decoder_out = model.joiner.decoder_proj(decoder_out)
# decoder_out is of shape (num_hyps, 1, 1, joiner_dim)
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
# as index, so we use `to(torch.int64)` below.
current_encoder_out = torch.index_select(
current_encoder_out,
dim=0,
index=hyps_shape.row_ids(1).to(torch.int64),
) # (num_hyps, 1, 1, encoder_out_dim)
logits = model.joiner(
current_encoder_out,
decoder_out,
project_input=False,
) # (num_hyps, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
if blank_penalty != 0:
logits[:, 0] -= blank_penalty
probs = logits.softmax(dim=-1) # (num_hyps, vocab_size)
log_probs = probs.log()
probs = probs.reshape(-1)
log_probs.add_(ys_log_probs)
vocab_size = log_probs.size(-1)
log_probs = log_probs.reshape(-1)
row_splits = hyps_shape.row_splits(1) * vocab_size
log_probs_shape = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=log_probs.numel()
)
ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
ragged_probs = k2.RaggedTensor(shape=log_probs_shape, value=probs)
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
hyp_probs = ragged_probs[i].tolist()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]
new_ys = hyp.ys[:]
new_token = topk_token_indexes[k]
new_timestamp = hyp.timestamp[:]
new_ac_probs = hyp.ac_probs[:]
context_score = 0
new_context_state = hyp.context_state
new_num_tailing_blanks = hyp.num_tailing_blanks + 1
if new_token not in (blank_id, unk_id):
new_ys.append(new_token)
new_timestamp.append(t)
new_ac_probs.append(hyp_probs[topk_indexes[k]])
(
context_score,
new_context_state,
_,
) = keywords_graph.forward_one_step(hyp.context_state, new_token)
new_num_tailing_blanks = 0
if new_context_state.token == -1: # root
new_ys[-context_size:] = [-1] * (context_size - 1) + [blank_id]
new_log_prob = topk_log_probs[k] + context_score
new_hyp = Hypothesis(
ys=new_ys,
log_prob=new_log_prob,
timestamp=new_timestamp,
ac_probs=new_ac_probs,
context_state=new_context_state,
num_tailing_blanks=new_num_tailing_blanks,
)
B[i].add(new_hyp)
top_hyp = B[i].get_most_probable(length_norm=True)
matched, matched_state = keywords_graph.is_matched(top_hyp.context_state)
if matched:
ac_prob = (
sum(top_hyp.ac_probs[-matched_state.level :]) / matched_state.level
)
if (
matched
and top_hyp.num_tailing_blanks > num_tailing_blanks
and ac_prob >= matched_state.ac_threshold
):
keyword = KeywordResult(
hyps=top_hyp.ys[-matched_state.level :],
timestamps=top_hyp.timestamp[-matched_state.level :],
phrase=matched_state.phrase,
)
sorted_ans[i].append(keyword)
B[i] = HypothesisList()
B[i].add(
Hypothesis(
ys=[-1] * (context_size - 1) + [blank_id],
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
context_state=keywords_graph.root,
timestamp=[],
ac_probs=[],
)
)
B = B + finalized_B
for i, hyps in enumerate(B):
top_hyp = hyps.get_most_probable(length_norm=True)
matched, matched_state = keywords_graph.is_matched(top_hyp.context_state)
if matched:
ac_prob = (
sum(top_hyp.ac_probs[-matched_state.level :]) / matched_state.level
)
if matched and ac_prob >= matched_state.ac_threshold:
keyword = KeywordResult(
hyps=top_hyp.ys[-matched_state.level :],
timestamps=top_hyp.timestamp[-matched_state.level :],
phrase=matched_state.phrase,
)
sorted_ans[i].append(keyword)
ans = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
return ans
def modified_beam_search(
model: nn.Module,
encoder_out: torch.Tensor,

View File

@ -78,10 +78,10 @@ It will generate the following 3 files inside $repo/exp:
import argparse
import logging
from icefall import is_module_available
import torch
from onnx_pretrained import OnnxModel
import torch
from icefall import is_module_available
def get_parser():

View File

@ -76,8 +76,7 @@ import torch
import torch.nn as nn
from asr_datamodule import AsrDataModule
from librispeech import LibriSpeech
from onnx_pretrained import greedy_search, OnnxModel
from onnx_pretrained import OnnxModel, greedy_search
from icefall.utils import setup_logger, store_transcripts, write_error_stats

View File

@ -22,15 +22,15 @@ Usage: ./pruned_transducer_stateless4/my_profile.py
import argparse
import logging
from typing import Tuple
import sentencepiece as spm
import torch
from typing import Tuple
from scaling import BasicNorm, DoubleSwish
from torch import Tensor, nn
from train import add_model_arguments, get_encoder_model, get_joiner_model, get_params
from icefall.profiler import get_model_profile
from scaling import BasicNorm, DoubleSwish
from train import get_encoder_model, get_joiner_model, add_model_arguments, get_params
def get_parser():

Some files were not shown because too many files have changed in this diff Show More