Merge branch 'k2-fsa:master' into fisher_swbd

commit 12a755107f by Nagendra Goel, 2022-09-29 13:23:00 -04:00 (committed by GitHub)
109 changed files with 11143 additions and 167 deletions

.flake8

@ -9,7 +9,7 @@ per-file-ignores =
egs/*/ASR/pruned_transducer_stateless*/*.py: E501,
egs/*/ASR/*/optim.py: E501,
egs/*/ASR/*/scaling.py: E501,
egs/librispeech/ASR/lstm_transducer_stateless/*.py: E501, E203
egs/librispeech/ASR/lstm_transducer_stateless*/*.py: E501, E203
egs/librispeech/ASR/conv_emformer_transducer_stateless*/*.py: E501, E203
egs/librispeech/ASR/conformer_ctc2/*py: E501,
egs/librispeech/ASR/RESULTS.md: E999,
@ -22,3 +22,11 @@ exclude =
**/data/**,
icefall/shared/make_kn_lm.py,
icefall/__init__.py
ignore =
# E203 white space before ":"
E203,
# W503 line break before binary operator
W503,
# E226 missing whitespace around arithmetic operator
E226,
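For context, each of the newly ignored codes corresponds to a style the project tolerates. A purely illustrative Python snippet (not from this repo) showing what each code would flag without these ignores:
```python
frames = list(range(100))
offset, context = 3, 8
sample_rate, duration, hop_length = 16000, 2.0, 160
min_duration, max_duration = 1.0, 20.0

# E203: space before ":" in a slice (the style black produces for complex bounds)
chunk = frames[offset + 1 : offset + context]

# W503: line break before a binary operator
keep = (
    duration > min_duration
    and duration < max_duration
)

# E226: missing whitespace around an arithmetic operator
num_frames = int(sample_rate*duration) // hop_length
```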


@ -0,0 +1,160 @@
#!/usr/bin/env bash
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
cd egs/librispeech/ASR
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
log "Downloading pre-trained model from $repo_url"
git lfs install
git clone $repo_url
repo=$(basename $repo_url)
log "Display test files"
tree $repo/
soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav
pushd $repo/exp
ln -s pretrained-iter-468000-avg-16.pt pretrained.pt
ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt
popd
log "Install ncnn and pnnx"
# We are using a modified ncnn here; we will try to merge it into the official
# ncnn repo.
git clone https://github.com/csukuangfj/ncnn
pushd ncnn
git submodule init
git submodule update python/pybind11
python3 setup.py bdist_wheel
ls -lh dist/
pip install dist/*.whl
cd tools/pnnx
mkdir build
cd build
cmake ..
make -j4 pnnx
./src/pnnx || echo "pass"
popd
log "Test exporting to pnnx format"
./lstm_transducer_stateless2/export.py \
--exp-dir $repo/exp \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 99 \
--avg 1 \
--use-averaged-model 0 \
--pnnx 1
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/encoder_jit_trace-pnnx.pt
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/decoder_jit_trace-pnnx.pt
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/joiner_jit_trace-pnnx.pt
./lstm_transducer_stateless2/ncnn-decode.py \
--bpe-model-filename $repo/data/lang_bpe_500/bpe.model \
--encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \
--encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \
--decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \
--decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \
--joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \
--joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \
$repo/test_wavs/1089-134686-0001.wav
./lstm_transducer_stateless2/streaming-ncnn-decode.py \
--bpe-model-filename $repo/data/lang_bpe_500/bpe.model \
--encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \
--encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \
--decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \
--decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \
--joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \
--joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \
$repo/test_wavs/1089-134686-0001.wav
log "Test exporting with torch.jit.trace()"
./lstm_transducer_stateless2/export.py \
--exp-dir $repo/exp \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 99 \
--avg 1 \
--use-averaged-model 0 \
--jit-trace 1
log "Decode with models exported by torch.jit.trace()"
./lstm_transducer_stateless2/jit_pretrained.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--encoder-model-filename $repo/exp/encoder_jit_trace.pt \
--decoder-model-filename $repo/exp/decoder_jit_trace.pt \
--joiner-model-filename $repo/exp/joiner_jit_trace.pt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
for sym in 1 2 3; do
log "Greedy search with --max-sym-per-frame $sym"
./lstm_transducer_stateless2/pretrained.py \
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
done
for method in modified_beam_search beam_search fast_beam_search; do
log "$method"
./lstm_transducer_stateless2/pretrained.py \
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
done
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"ncnn" ]]; then
mkdir -p lstm_transducer_stateless2/exp
ln -s $PWD/$repo/exp/pretrained.pt lstm_transducer_stateless2/exp/epoch-999.pt
ln -s $PWD/$repo/data/lang_bpe_500 data/
ls -lh data
ls -lh lstm_transducer_stateless2/exp
log "Decoding test-clean and test-other"
# use a small value for decoding with CPU
max_duration=100
for method in greedy_search fast_beam_search modified_beam_search; do
log "Decoding with $method"
./lstm_transducer_stateless2/decode.py \
--decoding-method $method \
--epoch 999 \
--avg 1 \
--use-averaged-model 0 \
--max-duration $max_duration \
--exp-dir lstm_transducer_stateless2/exp
done
rm lstm_transducer_stateless2/exp/*.pt
fi


@ -69,7 +69,7 @@ jobs:
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'


@ -68,7 +68,7 @@ jobs:
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'


@ -68,7 +68,7 @@ jobs:
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'


@ -68,7 +68,7 @@ jobs:
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'


@ -68,7 +68,7 @@ jobs:
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'


@ -0,0 +1,136 @@
name: run-librispeech-lstm-transducer-2022-09-03
on:
push:
branches:
- master
pull_request:
types: [labeled]
schedule:
# minute (0-59)
# hour (0-23)
# day of the month (1-31)
# month (1-12)
# day of the week (0-6)
# nightly build at 15:50 UTC time every day
- cron: "50 15 * * *"
jobs:
run_librispeech_lstm_transducer_stateless2_2022_09_03:
if: github.event.label.name == 'ncnn' || github.event_name == 'push' || github.event_name == 'schedule'
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
python-version: [3.8]
fail-fast: false
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'
cache-dependency-path: '**/requirements-ci.txt'
- name: Install Python dependencies
run: |
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
pip uninstall -y protobuf
pip install --no-binary protobuf protobuf
- name: Cache kaldifeat
id: my-cache
uses: actions/cache@v2
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/install-kaldifeat.sh
- name: Cache LibriSpeech test-clean and test-other datasets
id: libri-test-clean-and-test-other-data
uses: actions/cache@v2
with:
path: |
~/tmp/download
key: cache-libri-test-clean-and-test-other
- name: Download LibriSpeech test-clean and test-other
if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
- name: Prepare manifests for LibriSpeech test-clean and test-other
shell: bash
run: |
.github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
- name: Cache LibriSpeech test-clean and test-other fbank features
id: libri-test-clean-and-test-other-fbank
uses: actions/cache@v2
with:
path: |
~/tmp/fbank-libri
key: cache-libri-fbank-test-clean-and-test-other-v2
- name: Compute fbank for LibriSpeech test-clean and test-other
if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
- name: Inference with pre-trained model
shell: bash
env:
GITHUB_EVENT_NAME: ${{ github.event_name }}
GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
run: |
mkdir -p egs/librispeech/ASR/data
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
ls -lh egs/librispeech/ASR/data/*
sudo apt-get -qq install git-lfs tree sox
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh
- name: Display decoding results for lstm_transducer_stateless2
if: github.event_name == 'schedule' || github.event.label.name == 'ncnn'
shell: bash
run: |
cd egs/librispeech/ASR
tree lstm_transducer_stateless2/exp
cd lstm_transducer_stateless2/exp
echo "===greedy search==="
find greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
echo "===fast_beam_search==="
find fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
echo "===modified beam search==="
find modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
- name: Upload decoding results for lstm_transducer_stateless2
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'ncnn'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-lstm_transducer_stateless2-2022-09-03
path: egs/librispeech/ASR/lstm_transducer_stateless2/exp/


@ -68,7 +68,7 @@ jobs:
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'


@ -68,7 +68,7 @@ jobs:
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'


@ -68,7 +68,7 @@ jobs:
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'


@ -58,7 +58,7 @@ jobs:
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'


@ -67,7 +67,7 @@ jobs:
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'


@ -67,7 +67,7 @@ jobs:
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'


@ -58,7 +58,7 @@ jobs:
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'


@ -58,7 +58,7 @@ jobs:
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'


@ -67,7 +67,7 @@ jobs:
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'


@ -58,7 +58,7 @@ jobs:
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'

.gitignore

@ -11,3 +11,5 @@ log
*.bak
*-bak
*bak.py
*.param
*.bin


@ -1,24 +1,114 @@
# icefall dockerfile
We provide a dockerfile for some users, the configuration of dockerfile is : Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8-python3.8. You can use the dockerfile by following the steps:
2 sets of configuration are provided - (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8, and (b) Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8.
## Building images locally
If your NVIDIA driver supports CUDA Version: 11.3, please go for case (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8.
Otherwise, since the older PyTorch images are not updated with the [apt-key rotation by NVIDIA](https://developer.nvidia.com/blog/updating-the-cuda-linux-gpg-repository-key), you have to go for case (b) Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8. Ensure that your NVIDIA driver supports at least CUDA 11.0.
You can check the highest CUDA version within your NVIDIA driver's support with the `nvidia-smi` command below. In this example, the highest CUDA version is 11.0, i.e. case (b) Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8.
```bash
cd docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8
docker build -t icefall/pytorch1.7.1:latest -f ./Dockerfile ./
$ nvidia-smi
Tue Sep 20 00:26:13 2022
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.119.03 Driver Version: 450.119.03 CUDA Version: 11.0 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 TITAN RTX On | 00000000:03:00.0 Off | N/A |
| 41% 31C P8 4W / 280W | 16MiB / 24219MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
| 1 TITAN RTX On | 00000000:04:00.0 Off | N/A |
| 41% 30C P8 11W / 280W | 6MiB / 24220MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| 0 N/A N/A 2085 G /usr/lib/xorg/Xorg 9MiB |
| 0 N/A N/A 2240 G /usr/bin/gnome-shell 4MiB |
| 1 N/A N/A 2085 G /usr/lib/xorg/Xorg 4MiB |
+-----------------------------------------------------------------------------+
```
## Using built images
Sample usage of the GPU based images:
## Building images locally
If your environment requires a proxy to access the Internet, remember to add that information to the Dockerfile directly.
In most cases, you can uncomment these lines in the Dockerfile and fill in your proxy details.
```dockerfile
ENV http_proxy=http://aaa.bb.cc.net:8080 \
https_proxy=http://aaa.bb.cc.net:8080
```
Then, proceed with these commands.
### If you are case (a), i.e. your NVIDIA driver supports CUDA version >= 11.3:
```bash
cd docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8
docker build -t icefall/pytorch1.12.1 .
```
### If you are case (b), i.e. your NVIDIA driver can only support CUDA versions 11.0 <= x < 11.3:
```bash
cd docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8
docker build -t icefall/pytorch1.7.1 .
```
## Running your built local image
Sample usage of the GPU based images. These commands are written with case (a) in mind, so please make the necessary changes to your image name if you are case (b).
Note: use [nvidia-docker](https://github.com/NVIDIA/nvidia-docker) to run the GPU images.
```bash
docker run -it --runtime=nvidia --name=icefall_username --gpus all icefall/pytorch1.7.1:latest
docker run -it --runtime=nvidia --shm-size=2gb --name=icefall --gpus all icefall/pytorch1.12.1
```
Sample usage of the CPU based images:
### Tips:
1. Since your data and models most probably won't be inside the docker container, you must use the -v flag to access the host machine. Do this by specifying `-v {/path/in/host/machine}:{/path/in/docker}`.
2. Also, if your environment requires a proxy, this would be a good time to add it in too: `-e http_proxy=http://aaa.bb.cc.net:8080 -e https_proxy=http://aaa.bb.cc.net:8080`.
Overall, your docker run command should look like this.
```bash
docker run -it icefall/pytorch1.7.1:latest /bin/bash
```
docker run -it --runtime=nvidia --shm-size=2gb --name=icefall --gpus all -v {/path/in/host/machine}:{/path/in/docker} -e http_proxy=http://aaa.bb.cc.net:8080 -e https_proxy=http://aaa.bb.cc.net:8080 icefall/pytorch1.12.1
```
You can explore more docker run options [here](https://docs.docker.com/engine/reference/commandline/run/) to suit your environment.
### Linking to icefall in your host machine
If you already have icefall downloaded onto your host machine, you can use that repository instead so that changes in your code are visible inside and outside of the container.
Note: Remember to set the -v flag above during the first run of the container, as that is the only way for your container to access your host machine.
Warning: Check that the icefall in your host machine is visible from within your container before proceeding to the commands below.
Use these commands once you are inside the container.
```bash
rm -r /workspace/icefall
ln -s {/path/in/docker/to/icefall} /workspace/icefall
```
## Starting another session in the same running container.
```bash
docker exec -it icefall /bin/bash
```
## Restarting a killed container that has been run before.
```bash
docker start -ai icefall
```
## Sample usage of the CPU based images:
```bash
docker run -it icefall /bin/bash
```


@ -0,0 +1,72 @@
FROM pytorch/pytorch:1.12.1-cuda11.3-cudnn8-devel
# ENV http_proxy=http://aaa.bbb.cc.net:8080 \
# https_proxy=http://aaa.bbb.cc.net:8080
# install normal source
RUN apt-get update && \
apt-get install -y --no-install-recommends \
g++ \
make \
automake \
autoconf \
bzip2 \
unzip \
wget \
sox \
libtool \
git \
subversion \
zlib1g-dev \
gfortran \
ca-certificates \
patch \
ffmpeg \
valgrind \
libssl-dev \
vim \
curl
# cmake
RUN wget -P /opt https://cmake.org/files/v3.18/cmake-3.18.0.tar.gz && \
cd /opt && \
tar -zxvf cmake-3.18.0.tar.gz && \
cd cmake-3.18.0 && \
./bootstrap && \
make && \
make install && \
rm -rf cmake-3.18.0.tar.gz && \
find /opt/cmake-3.18.0 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \
cd -
# flac
RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz && \
cd /opt && \
xz -d flac-1.3.2.tar.xz && \
tar -xvf flac-1.3.2.tar && \
cd flac-1.3.2 && \
./configure && \
make && make install && \
rm -rf flac-1.3.2.tar && \
find /opt/flac-1.3.2 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \
cd -
RUN pip install kaldiio graphviz && \
conda install -y -c pytorch torchaudio
#install k2 from source
RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \
cd /opt/k2 && \
python3 setup.py install && \
cd -
# install lhotse
RUN pip install git+https://github.com/lhotse-speech/lhotse
RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
cd /workspace/icefall && \
pip install -r requirements.txt
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
WORKDIR /workspace/icefall


@ -1,7 +1,13 @@
FROM pytorch/pytorch:1.7.1-cuda11.0-cudnn8-devel
# install normal source
# ENV http_proxy=http://aaa.bbb.cc.net:8080 \
# https_proxy=http://aaa.bbb.cc.net:8080
RUN rm /etc/apt/sources.list.d/cuda.list && \
rm /etc/apt/sources.list.d/nvidia-ml.list && \
apt-key del 7fa2af80
# install normal source
RUN apt-get update && \
apt-get install -y --no-install-recommends \
g++ \
@ -21,20 +27,25 @@ RUN apt-get update && \
patch \
ffmpeg \
valgrind \
libssl-dev \
vim && \
rm -rf /var/lib/apt/lists/*
libssl-dev \
vim \
curl
RUN mv /opt/conda/lib/libcufft.so.10 /opt/libcufft.so.10.bak && \
# Add new keys and reupdate
RUN curl -fsSL https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub | apt-key add - && \
curl -fsSL https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub | apt-key add - && \
echo "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/cuda.list && \
echo "deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/nvidia-ml.list && \
rm -rf /var/lib/apt/lists/* && \
mv /opt/conda/lib/libcufft.so.10 /opt/libcufft.so.10.bak && \
mv /opt/conda/lib/libcurand.so.10 /opt/libcurand.so.10.bak && \
mv /opt/conda/lib/libcublas.so.11 /opt/libcublas.so.11.bak && \
mv /opt/conda/lib/libnvrtc.so.11.0 /opt/libnvrtc.so.11.1.bak && \
mv /opt/conda/lib/libnvToolsExt.so.1 /opt/libnvToolsExt.so.1.bak && \
mv /opt/conda/lib/libcudart.so.11.0 /opt/libcudart.so.11.0.bak
# mv /opt/conda/lib/libnvToolsExt.so.1 /opt/libnvToolsExt.so.1.bak && \
mv /opt/conda/lib/libcudart.so.11.0 /opt/libcudart.so.11.0.bak && \
apt-get update && apt-get -y upgrade
# cmake
RUN wget -P /opt https://cmake.org/files/v3.18/cmake-3.18.0.tar.gz && \
cd /opt && \
tar -zxvf cmake-3.18.0.tar.gz && \
@ -45,11 +56,7 @@ RUN wget -P /opt https://cmake.org/files/v3.18/cmake-3.18.0.tar.gz && \
rm -rf cmake-3.18.0.tar.gz && \
find /opt/cmake-3.18.0 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \
cd -
#kaldiio
RUN pip install kaldiio
# flac
RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz && \
cd /opt && \
@ -62,15 +69,8 @@ RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz &&
find /opt/flac-1.3.2 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \
cd -
# graphviz
RUN pip install graphviz
# kaldifeat
RUN git clone https://github.com/csukuangfj/kaldifeat.git /opt/kaldifeat && \
cd /opt/kaldifeat && \
python setup.py install && \
cd -
RUN pip install kaldiio graphviz && \
conda install -y -c pytorch torchaudio=0.7.1
#install k2 from source
RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \
@ -79,14 +79,13 @@ RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \
cd -
# install lhotse
RUN pip install torchaudio==0.7.2
RUN pip install git+https://github.com/lhotse-speech/lhotse
#RUN pip install lhotse
RUN pip install git+https://github.com/lhotse-speech/lhotse
# install icefall
RUN git clone https://github.com/k2-fsa/icefall && \
cd icefall && \
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
cd /workspace/icefall && \
pip install -r requirements.txt
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
WORKDIR /workspace/icefall

Binary file not shown (new image, 413 KiB)


@ -6,3 +6,4 @@ LibriSpeech
tdnn_lstm_ctc
conformer_ctc
lstm_pruned_stateless_transducer


@ -0,0 +1,634 @@
LSTM Transducer
===============
.. hint::
Please scroll down to the bottom of this page to find download links
for pretrained models if you don't want to train a model from scratch.
This tutorial shows you how to train an LSTM transducer model
with the `LibriSpeech <https://www.openslr.org/12>`_ dataset.
We use pruned RNN-T to compute the loss.
.. note::
You can find the paper about pruned RNN-T at the following address:
`<https://arxiv.org/abs/2206.13236>`_
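As a one-line refresher (our summary, not wording from the paper): the transducer loss is the negative log-probability of the reference label sequence, obtained by summing over all blank-augmented alignments, and the pruned variant restricts that sum to a narrow band of plausible alignments to save memory and compute.
.. math::
L = -\log P(y \mid x) = -\log \sum_{a \in \mathcal{B}^{-1}(y)} P(a \mid x)
where :math:`\mathcal{B}` denotes the mapping that removes blank symbols from an alignment :math:`a`.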
The transducer model consists of 3 parts:
- Encoder, a.k.a. the transcription network. We use an LSTM model
- Decoder, a.k.a. the prediction network. We use a stateless model consisting of
``nn.Embedding`` and ``nn.Conv1d``
- Joiner, a.k.a. the joint network.
.. caution::
Contrary to the conventional RNN-T models, we use a stateless decoder.
That is, it has no recurrent connections.
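To make the term "stateless" concrete, below is a minimal sketch of such a decoder; the names and sizes are illustrative only (see the recipe directory for the actual implementation).
.. code-block:: python

import torch
import torch.nn as nn

class StatelessDecoder(nn.Module):
    """A sketch of a stateless prediction network: it only looks at the
    last few emitted tokens via a small causal Conv1d over embeddings,
    so there is no recurrent state to carry across frames."""

    def __init__(self, vocab_size: int, embed_dim: int = 512, context_size: int = 2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # A depthwise conv over the last `context_size` token embeddings.
        self.conv = nn.Conv1d(
            embed_dim, embed_dim, kernel_size=context_size, groups=embed_dim
        )
        self.context_size = context_size

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (batch, num_tokens) of token ids
        emb = self.embedding(y).permute(0, 2, 1)   # (batch, embed_dim, num_tokens)
        # Left-pad so the conv is causal and keeps the sequence length.
        emb = nn.functional.pad(emb, (self.context_size - 1, 0))
        return self.conv(emb).permute(0, 2, 1)     # (batch, num_tokens, embed_dim)

decoder = StatelessDecoder(vocab_size=500)
print(decoder(torch.randint(0, 500, (4, 10))).shape)  # torch.Size([4, 10, 512])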
.. hint::
Since the encoder model is an LSTM, not Transformer/Conformer, the
resulting model is suitable for streaming/online ASR.
Which model to use
------------------
Currently, there are two folders about LSTM stateless transducer training:
- ``(1)`` `<https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/lstm_transducer_stateless>`_
This recipe uses only LibriSpeech during training.
- ``(2)`` `<https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/lstm_transducer_stateless2>`_
This recipe uses GigaSpeech + LibriSpeech during training.
``(1)`` and ``(2)`` use the same model architecture. The only difference is that ``(2)`` supports
multi-dataset. Since ``(2)`` uses more data, it has a lower WER than ``(1)`` but it needs
more training time.
We use ``lstm_transducer_stateless2`` as an example below.
.. note::
You need to download the `GigaSpeech <https://github.com/SpeechColab/GigaSpeech>`_ dataset
to run ``(2)``. If you have only ``LibriSpeech`` dataset available, feel free to use ``(1)``.
Data preparation
----------------
.. code-block:: bash
$ cd egs/librispeech/ASR
$ ./prepare.sh
# If you use (1), you can **skip** the following command
$ ./prepare_giga_speech.sh
The script ``./prepare.sh`` handles the data preparation for you, **automagically**.
All you need to do is to run it.
.. note::
We encourage you to read ``./prepare.sh``.
The data preparation contains several stages. You can use the following two
options:
- ``--stage``
- ``--stop-stage``
to control which stage(s) should be run. By default, all stages are executed.
For example,
.. code-block:: bash
$ cd egs/librispeech/ASR
$ ./prepare.sh --stage 0 --stop-stage 0
means to run only stage 0.
To run stage 2 to stage 5, use:
.. code-block:: bash
$ ./prepare.sh --stage 2 --stop-stage 5
.. hint::
If you have pre-downloaded the `LibriSpeech <https://www.openslr.org/12>`_
dataset and the `musan <http://www.openslr.org/17/>`_ dataset, say,
they are saved in ``/tmp/LibriSpeech`` and ``/tmp/musan``, you can modify
the ``dl_dir`` variable in ``./prepare.sh`` to point to ``/tmp`` so that
``./prepare.sh`` won't re-download them.
.. note::
All files generated by ``./prepare.sh``, e.g., features, lexicon, etc.,
are saved in the ``./data`` directory.
We provide the following YouTube video showing how to run ``./prepare.sh``.
.. note::
To get the latest news of `next-gen Kaldi <https://github.com/k2-fsa>`_, please subscribe to
the following YouTube channel by `Nadira Povey <https://www.youtube.com/channel/UC_VaumpkmINz1pNkFXAN9mw>`_:
`<https://www.youtube.com/channel/UC_VaumpkmINz1pNkFXAN9mw>`_
.. youtube:: ofEIoJL-mGM
Training
--------
Configurable options
~~~~~~~~~~~~~~~~~~~~
.. code-block:: bash
$ cd egs/librispeech/ASR
$ ./lstm_transducer_stateless2/train.py --help
shows you the training options that can be passed from the commandline.
The following options are used quite often:
- ``--full-libri``
If it's True, the training part uses all the training data, i.e.,
960 hours. Otherwise, the training part uses only the subset
``train-clean-100``, which has 100 hours of training data.
.. CAUTION::
The training set is perturbed by speed with two factors: 0.9 and 1.1.
If ``--full-libri`` is True, each epoch actually processes
``3x960 == 2880`` hours of data.
- ``--num-epochs``
It is the number of epochs to train. For instance,
``./lstm_transducer_stateless2/train.py --num-epochs 30`` trains for 30 epochs
and generates ``epoch-1.pt``, ``epoch-2.pt``, ..., ``epoch-30.pt``
in the folder ``./lstm_transducer_stateless2/exp``.
- ``--start-epoch``
It's used to resume training.
``./lstm_transducer_stateless2/train.py --start-epoch 10`` loads the
checkpoint ``./lstm_transducer_stateless2/exp/epoch-9.pt`` and starts
training from epoch 10, based on the state from epoch 9.
- ``--world-size``
It is used for multi-GPU single-machine DDP training.
- (a) If it is 1, then no DDP training is used.
- (b) If it is 2, then GPU 0 and GPU 1 are used for DDP training.
The following shows some use cases with it.
**Use case 1**: You have 4 GPUs, but you only want to use GPU 0 and
GPU 2 for training. You can do the following:
.. code-block:: bash
$ cd egs/librispeech/ASR
$ export CUDA_VISIBLE_DEVICES="0,2"
$ ./lstm_transducer_stateless2/train.py --world-size 2
**Use case 2**: You have 4 GPUs and you want to use all of them
for training. You can do the following:
.. code-block:: bash
$ cd egs/librispeech/ASR
$ ./lstm_transducer_stateless2/train.py --world-size 4
**Use case 3**: You have 4 GPUs but you only want to use GPU 3
for training. You can do the following:
.. code-block:: bash
$ cd egs/librispeech/ASR
$ export CUDA_VISIBLE_DEVICES="3"
$ ./lstm_transducer_stateless2/train.py --world-size 1
.. caution::
Only multi-GPU single-machine DDP training is implemented at present.
Multi-GPU multi-machine DDP training will be added later.
- ``--max-duration``
It specifies the number of seconds over all utterances in a
batch, before **padding**.
If you encounter CUDA OOM, please reduce it.
.. HINT::
Due to padding, the number of seconds of all utterances in a
batch will usually be larger than ``--max-duration``.
A larger value for ``--max-duration`` may cause OOM during training,
while a smaller value may increase the training time. You have to
tune it.
- ``--giga-prob``
The probability to select a batch from the ``GigaSpeech`` dataset.
Note: It is available only for ``(2)``.
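Relating to the ``--max-duration`` option above: batching in these recipes is duration-based rather than using a fixed batch size, roughly as in the simplified sketch below (illustrative only, not the recipe's actual data module).
.. code-block:: python

from lhotse import CutSet
from lhotse.dataset import DynamicBucketingSampler

# Illustrative path; prepare.sh writes the actual manifests under ./data.
cuts = CutSet.from_file("data/fbank/librispeech_cuts_train-clean-100.jsonl.gz")

# Each mini-batch holds at most ~500 s of audio *before* padding; lowering
# max_duration is the first thing to try if you hit CUDA OOM.
sampler = DynamicBucketingSampler(cuts, max_duration=500.0, shuffle=True)

for batch_cuts in sampler:
    pass  # each batch is a CutSet whose total duration is <= max_duration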
Pre-configured options
~~~~~~~~~~~~~~~~~~~~~~
There are some training options, e.g., weight decay,
number of warmup steps, results dir, etc.,
that are not passed from the commandline.
They are pre-configured by the function ``get_params()`` in
`lstm_transducer_stateless2/train.py <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/lstm_transducer_stateless2/train.py>`_
You don't need to change these pre-configured parameters. If you really need to change
them, please modify ``./lstm_transducer_stateless2/train.py`` directly.
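For a sense of what ``get_params()`` looks like, here is a rough sketch; the keys and values below are placeholders, not the recipe's actual settings, so consult ``train.py`` for the real ones.
.. code-block:: python

from icefall.utils import AttributeDict

def get_params() -> AttributeDict:
    # Placeholder values for illustration only.
    return AttributeDict(
        {
            "exp_dir": "lstm_transducer_stateless2/exp",
            "feature_dim": 80,
            "subsampling_factor": 4,
            "log_interval": 50,
            "save_every_n": 2000,
        }
    )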
Training logs
~~~~~~~~~~~~~
Training logs and checkpoints are saved in ``lstm_transducer_stateless2/exp``.
You will find the following files in that directory:
- ``epoch-1.pt``, ``epoch-2.pt``, ...
These are checkpoint files saved at the end of each epoch, containing model
``state_dict`` and optimizer ``state_dict``.
To resume training from some checkpoint, say ``epoch-10.pt``, you can use:
.. code-block:: bash
$ ./lstm_transducer_stateless2/train.py --start-epoch 11
- ``checkpoint-436000.pt``, ``checkpoint-438000.pt``, ...
These are checkpoint files saved every ``--save-every-n`` batches,
containing model ``state_dict`` and optimizer ``state_dict``.
To resume training from some checkpoint, say ``checkpoint-436000``, you can use:
.. code-block:: bash
$ ./lstm_transducer_stateless2/train.py --start-batch 436000
- ``tensorboard/``
This folder contains TensorBoard logs. Training loss, validation loss, learning
rate, etc., are recorded in these logs. You can visualize them by:
.. code-block:: bash
$ cd lstm_transducer_stateless2/exp/tensorboard
$ tensorboard dev upload --logdir . --description "LSTM transducer training for LibriSpeech with icefall"
It will print something like below:
.. code-block::
TensorFlow installation not found - running with reduced feature set.
Upload started and will continue reading any new data as it's added to the logdir.
To stop uploading, press Ctrl-C.
New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/cj2vtPiwQHKN9Q1tx6PTpg/
[2022-09-20T15:50:50] Started scanning logdir.
Uploading 4468 scalars...
[2022-09-20T15:53:02] Total uploaded: 210171 scalars, 0 tensors, 0 binary objects
Listening for new data in logdir...
Note there is a URL in the above output. Click it and you will see
the following screenshot:
.. figure:: images/librispeech-lstm-transducer-tensorboard-log.png
:width: 600
:alt: TensorBoard screenshot
:align: center
:target: https://tensorboard.dev/experiment/lzGnETjwRxC3yghNMd4kPw/
TensorBoard screenshot.
.. hint::
If you don't have access to Google, you can use the following command
to view the tensorboard log locally:
.. code-block:: bash
cd lstm_transducer_stateless2/exp/tensorboard
tensorboard --logdir . --port 6008
It will print the following message:
.. code-block::
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.8.0 at http://localhost:6008/ (Press CTRL+C to quit)
Now start your browser and go to `<http://localhost:6008>`_ to view the tensorboard
logs.
- ``log/log-train-xxxx``
It is the detailed training log in text format, same as the one
you saw printed to the console during training.
Usage example
~~~~~~~~~~~~~
You can use the following command to start the training using 8 GPUs:
.. code-block:: bash
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
./lstm_transducer_stateless2/train.py \
--world-size 8 \
--num-epochs 35 \
--start-epoch 1 \
--full-libri 1 \
--exp-dir lstm_transducer_stateless2/exp \
--max-duration 500 \
--use-fp16 0 \
--lr-epochs 10 \
--num-workers 2 \
--giga-prob 0.9
Decoding
--------
The decoding part uses checkpoints saved by the training part, so you have
to run the training part first.
.. hint::
There are two kinds of checkpoints:
- (1) ``epoch-1.pt``, ``epoch-2.pt``, ..., which are saved at the end
of each epoch. You can pass ``--epoch`` to
``lstm_transducer_stateless2/decode.py`` to use them.
- (2) ``checkpoint-436000.pt``, ``checkpoint-438000.pt``, ..., which are saved
every ``--save-every-n`` batches. You can pass ``--iter`` to
``lstm_transducer_stateless2/decode.py`` to use them.
We suggest that you try both types of checkpoints and choose the one
that produces the lowest WERs.
.. code-block:: bash
$ cd egs/librispeech/ASR
$ ./lstm_transducer_stateless2/decode.py --help
shows the options for decoding.
The following shows two examples:
.. code-block:: bash
for m in greedy_search fast_beam_search modified_beam_search; do
for epoch in 17; do
for avg in 1 2; do
./lstm_transducer_stateless2/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir lstm_transducer_stateless2/exp \
--max-duration 600 \
--num-encoder-layers 12 \
--rnn-hidden-size 1024 \
--decoding-method $m \
--use-averaged-model True \
--beam 4 \
--max-contexts 4 \
--max-states 8 \
--beam-size 4
done
done
done
.. code-block:: bash
for m in greedy_search fast_beam_search modified_beam_search; do
for iter in 474000; do
for avg in 8 10 12 14 16 18; do
./lstm_transducer_stateless2/decode.py \
--iter $iter \
--avg $avg \
--exp-dir lstm_transducer_stateless2/exp \
--max-duration 600 \
--num-encoder-layers 12 \
--rnn-hidden-size 1024 \
--decoding-method $m \
--use-averaged-model True \
--beam 4 \
--max-contexts 4 \
--max-states 8 \
--beam-size 4
done
done
done
Export models
-------------
`lstm_transducer_stateless2/export.py <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/lstm_transducer_stateless2/export.py>`_ supports exporting checkpoints from ``lstm_transducer_stateless2/exp`` in the following ways.
Export ``model.state_dict()``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Checkpoints saved by ``lstm_transducer_stateless2/train.py`` also include
``optimizer.state_dict()``. It is useful for resuming training. But after training,
we are interested only in ``model.state_dict()``. You can use the following
command to extract ``model.state_dict()``.
.. code-block:: bash
# Assume that --iter 468000 --avg 16 produces the smallest WER
# (You can get such information after running ./lstm_transducer_stateless2/decode.py)
iter=468000
avg=16
./lstm_transducer_stateless2/export.py \
--exp-dir ./lstm_transducer_stateless2/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--iter $iter \
--avg $avg
It will generate a file ``./lstm_transducer_stateless2/exp/pretrained.pt``.
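Conceptually, extracting ``model.state_dict()`` with ``--iter 468000 --avg 16`` boils down to averaging the ``model`` entries of the last 16 checkpoints ending at ``checkpoint-468000.pt`` and saving the result. A simplified sketch (not the actual ``export.py``; it assumes checkpoints were saved every 2000 batches):
.. code-block:: python

import torch

exp_dir = "lstm_transducer_stateless2/exp"
iter_, avg, save_every_n = 468000, 16, 2000

ckpts = [f"{exp_dir}/checkpoint-{iter_ - save_every_n * i}.pt" for i in range(avg)]

avg_state = None
for path in ckpts:
    state = torch.load(path, map_location="cpu")["model"]
    if avg_state is None:
        avg_state = {k: v.clone().float() for k, v in state.items()}
    else:
        for k, v in state.items():
            avg_state[k] += v.float()
avg_state = {k: v / len(ckpts) for k, v in avg_state.items()}

torch.save({"model": avg_state}, f"{exp_dir}/pretrained.pt")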
.. hint::
To use the generated ``pretrained.pt`` for ``lstm_transducer_stateless2/decode.py``,
you can run:
.. code-block:: bash
cd lstm_transducer_stateless2/exp
ln -s pretrained.pt epoch-9999.pt
And then pass ``--epoch 9999 --avg 1 --use-averaged-model 0`` to
``./lstm_transducer_stateless2/decode.py``.
To use the exported model with ``./lstm_transducer_stateless2/pretrained.py``, you
can run:
.. code-block:: bash
./lstm_transducer_stateless2/pretrained.py \
--checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
Export model using ``torch.jit.trace()``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code-block:: bash
iter=468000
avg=16
./lstm_transducer_stateless2/export.py \
--exp-dir ./lstm_transducer_stateless2/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--iter $iter \
--avg $avg \
--jit-trace 1
It will generate 3 files:
- ``./lstm_transducer_stateless2/exp/encoder_jit_trace.pt``
- ``./lstm_transducer_stateless2/exp/decoder_jit_trace.pt``
- ``./lstm_transducer_stateless2/exp/joiner_jit_trace.pt``
To use the generated files with ``./lstm_transducer_stateless2/jit_pretrained.py``:
.. code-block:: bash
./lstm_transducer_stateless2/jit_pretrained.py \
--bpe-model ./data/lang_bpe_500/bpe.model \
--encoder-model-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace.pt \
--decoder-model-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace.pt \
--joiner-model-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace.pt \
/path/to/foo.wav \
/path/to/bar.wav
.. hint::
Please see `<https://k2-fsa.github.io/sherpa/python/streaming_asr/lstm/english/server.html>`_
for how to use the exported models in ``sherpa``.
Export model for ncnn
~~~~~~~~~~~~~~~~~~~~~
We support exporting pretrained LSTM transducer models to
`ncnn <https://github.com/tencent/ncnn>`_ using
`pnnx <https://github.com/Tencent/ncnn/tree/master/tools/pnnx>`_.
First, let us install a modified version of ``ncnn``:
.. code-block:: bash
git clone https://github.com/csukuangfj/ncnn
cd ncnn
git submodule update --recursive --init
python3 setup.py bdist_wheel
ls -lh dist/
pip install ./dist/*.whl
# now build pnnx
cd tools/pnnx
mkdir build
cd build
cmake ..
make -j4
export PATH=$PWD/src:$PATH
./src/pnnx
.. note::
We assume that you have added the path to the binary ``pnnx`` to the
environment variable ``PATH``.
Second, let us export the model using ``torch.jit.trace()`` that is suitable
for ``pnnx``:
.. code-block:: bash
iter=468000
avg=16
./lstm_transducer_stateless2/export.py \
--exp-dir ./lstm_transducer_stateless2/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--iter $iter \
--avg $avg \
--pnnx 1
It will generate 3 files:
- ``./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.pt``
- ``./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.pt``
- ``./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.pt``
Third, convert torchscript model to ``ncnn`` format:
.. code-block::
pnnx ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.pt
pnnx ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.pt
pnnx ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.pt
It will generate the following files:
- ``./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.param``
- ``./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.bin``
- ``./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.param``
- ``./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.bin``
- ``./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.param``
- ``./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.bin``
To use the above generated files, run:
.. code-block:: bash
./lstm_transducer_stateless2/ncnn-decode.py \
--bpe-model-filename ./data/lang_bpe_500/bpe.model \
--encoder-param-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.param \
--encoder-bin-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.bin \
--decoder-param-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.param \
--decoder-bin-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.bin \
--joiner-param-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.param \
--joiner-bin-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.bin \
/path/to/foo.wav
.. code-block:: bash
./lstm_transducer_stateless2/streaming-ncnn-decode.py \
--bpe-model-filename ./data/lang_bpe_500/bpe.model \
--encoder-param-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.param \
--encoder-bin-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.bin \
--decoder-param-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.param \
--decoder-bin-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.bin \
--joiner-param-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.param \
--joiner-bin-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.bin \
/path/to/foo.wav
To use the above generated files in C++, please see
`<https://github.com/k2-fsa/sherpa-ncnn>`_
It is able to generate a statically linked executable that can be run on Linux, Windows,
macOS, Raspberry Pi, etc., without external dependencies.
Download pretrained models
--------------------------
If you don't want to train from scratch, you can download the pretrained models
by visiting the following links:
- `<https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03>`_
- `<https://huggingface.co/Zengwei/icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18>`_
See `<https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS.md>`_
for the details of the above pretrained models.
You can find more usages of the pretrained models in
`<https://k2-fsa.github.io/sherpa/python/streaming_asr/lstm/index.html>`_


@ -62,6 +62,13 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
)
assert manifests is not None
assert len(manifests) == len(dataset_parts), (
len(manifests),
len(dataset_parts),
list(manifests.keys()),
dataset_parts,
)
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once.


@ -248,7 +248,9 @@ class ConformerEncoderLayer(nn.Module):
residual = src
if self.normalize_before:
src = self.norm_conv(src)
src = residual + self.dropout(self.conv_module(src))
src = residual + self.dropout(
self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
)
if not self.normalize_before:
src = self.norm_conv(src)
@ -879,11 +881,16 @@ class ConvolutionModule(nn.Module):
)
self.activation = Swish()
def forward(self, x: Tensor) -> Tensor:
def forward(
self,
x: Tensor,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
"""Compute convolution module.
Args:
x: Input tensor (#time, batch, channels).
src_key_padding_mask: the mask for the src keys per batch (optional).
Returns:
Tensor: Output tensor (#time, batch, channels).
@ -897,6 +904,8 @@ class ConvolutionModule(nn.Module):
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
# 1D Depthwise Conv
if src_key_padding_mask is not None:
x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
x = self.depthwise_conv(x)
x = self.activation(self.norm(x))


@ -248,7 +248,9 @@ class ConformerEncoderLayer(nn.Module):
residual = src
if self.normalize_before:
src = self.norm_conv(src)
src = residual + self.dropout(self.conv_module(src))
src = residual + self.dropout(
self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
)
if not self.normalize_before:
src = self.norm_conv(src)
@ -879,11 +881,16 @@ class ConvolutionModule(nn.Module):
)
self.activation = Swish()
def forward(self, x: Tensor) -> Tensor:
def forward(
self,
x: Tensor,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
"""Compute convolution module.
Args:
x: Input tensor (#time, batch, channels).
src_key_padding_mask: the mask for the src keys per batch (optional).
Returns:
Tensor: Output tensor (#time, batch, channels).
@ -897,6 +904,8 @@ class ConvolutionModule(nn.Module):
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
# 1D Depthwise Conv
if src_key_padding_mask is not None:
x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
x = self.depthwise_conv(x)
x = self.activation(self.norm(x))


@ -62,6 +62,13 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
)
assert manifests is not None
assert len(manifests) == len(dataset_parts), (
len(manifests),
len(dataset_parts),
list(manifests.keys()),
dataset_parts,
)
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once.


@ -62,6 +62,13 @@ def compute_fbank_aishell(num_mel_bins: int = 80):
)
assert manifests is not None
assert len(manifests) == len(dataset_parts), (
len(manifests),
len(dataset_parts),
list(manifests.keys()),
dataset_parts,
)
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once.


@ -246,7 +246,9 @@ class ConformerEncoderLayer(nn.Module):
residual = src
if self.normalize_before:
src = self.norm_conv(src)
src = residual + self.dropout(self.conv_module(src))
src = residual + self.dropout(
self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
)
if not self.normalize_before:
src = self.norm_conv(src)
@ -877,11 +879,16 @@ class ConvolutionModule(nn.Module):
)
self.activation = Swish()
def forward(self, x: Tensor) -> Tensor:
def forward(
self,
x: Tensor,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
"""Compute convolution module.
Args:
x: Input tensor (#time, batch, channels).
src_key_padding_mask: the mask for the src keys per batch (optional).
Returns:
Tensor: Output tensor (#time, batch, channels).
@ -895,6 +902,8 @@ class ConvolutionModule(nn.Module):
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
# 1D Depthwise Conv
if src_key_padding_mask is not None:
x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
x = self.depthwise_conv(x)
# x is (batch, channels, time)
x = x.permute(0, 2, 1)


@ -62,6 +62,13 @@ def compute_fbank_aishell2(num_mel_bins: int = 80):
)
assert manifests is not None
assert len(manifests) == len(dataset_parts), (
len(manifests),
len(dataset_parts),
list(manifests.keys()),
dataset_parts,
)
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once.


@ -63,6 +63,13 @@ def compute_fbank_aishell4(num_mel_bins: int = 80):
)
assert manifests is not None
assert len(manifests) == len(dataset_parts), (
len(manifests),
len(dataset_parts),
list(manifests.keys()),
dataset_parts,
)
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once.


@ -43,7 +43,7 @@ torch.set_num_interop_threads(1)
def compute_fbank_alimeeting(num_mel_bins: int = 80):
src_dir = Path("data/manifests")
src_dir = Path("data/manifests/alimeeting")
output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count())
@ -63,6 +63,13 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80):
)
assert manifests is not None
assert len(manifests) == len(dataset_parts), (
len(manifests),
len(dataset_parts),
list(manifests.keys()),
dataset_parts,
)
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once.


@ -30,9 +30,11 @@ with word segmenting:
import argparse
import paddle
import jieba
from tqdm import tqdm
paddle.enable_static()
jieba.enable_paddle()


@ -107,7 +107,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
# Prepare text.
# Note: in Linux, you can install jq with the following command:
# wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64
gunzip -c data/manifests/alimeeting/supervisions_train.jsonl.gz \
gunzip -c data/manifests/alimeeting/alimeeting_supervisions_train.jsonl.gz \
| jq ".text" | sed 's/"//g' \
| ./local/text2token.py -t "char" > $lang_char_dir/text


@ -253,7 +253,9 @@ class ConformerEncoderLayer(nn.Module):
residual = src
if self.normalize_before:
src = self.norm_conv(src)
src = residual + self.dropout(self.conv_module(src))
src = residual + self.dropout(
self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
)
if not self.normalize_before:
src = self.norm_conv(src)
@ -890,11 +892,16 @@ class ConvolutionModule(nn.Module):
)
self.activation = Swish()
def forward(self, x: Tensor) -> Tensor:
def forward(
self,
x: Tensor,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
"""Compute convolution module.
Args:
x: Input tensor (#time, batch, channels).
src_key_padding_mask: the mask for the src keys per batch (optional).
Returns:
Tensor: Output tensor (#time, batch, channels).
@ -908,6 +915,8 @@ class ConvolutionModule(nn.Module):
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
# 1D Depthwise Conv
if src_key_padding_mask is not None:
x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
x = self.depthwise_conv(x)
if self.use_batchnorm:
x = self.norm(x)


@ -62,6 +62,13 @@ def preprocess_giga_speech():
)
assert manifests is not None
assert len(manifests) == len(dataset_parts), (
len(manifests),
len(dataset_parts),
list(manifests.keys()),
dataset_parts,
)
for partition, m in manifests.items():
logging.info(f"Processing {partition}")
raw_cuts_path = output_dir / f"cuts_{partition}_raw.jsonl.gz"


@ -26,6 +26,7 @@ The following table lists the differences among them.
| `conv_emformer_transducer_stateless` | ConvEmformer | Embedding + Conv1d | Using ConvEmformer for streaming ASR + mechanisms in reworked model |
| `conv_emformer_transducer_stateless2` | ConvEmformer | Embedding + Conv1d | Using ConvEmformer with simplified memory for streaming ASR + mechanisms in reworked model |
| `lstm_transducer_stateless` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model |
| `lstm_transducer_stateless2` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model + gigaspeech (multi-dataset setup) |
The decoder in `transducer_stateless` is modified from the paper
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).


@ -1,14 +1,174 @@
## Results
#### LibriSpeech BPE training results (Pruned Stateless LSTM RNN-T)
### LibriSpeech BPE training results (Pruned Stateless LSTM RNN-T + gradient filter)
[lstm_transducer_stateless](./lstm_transducer_stateless)
#### [lstm_transducer_stateless3](./lstm_transducer_stateless3)
It implements an LSTM model with the mechanisms of the reworked model for streaming ASR.
A gradient filter is applied inside each LSTM module to stabilize the training.
See <https://github.com/k2-fsa/icefall/pull/564> for more details.
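For intuition, the gradient filter suppresses abnormally large gradients flowing back through the LSTM (compare the `--grad-norm-threshold 25.0` option in the training command below). Below is a simplified sketch of such a filter as a backward hook; it is not the recipe's exact implementation.

```python
import torch

def make_grad_filter(threshold: float = 25.0):
    """Zero the gradient of utterances whose gradient norm exceeds
    `threshold` times the batch median (simplified illustration)."""

    def filter_grad(grad: torch.Tensor) -> torch.Tensor:
        # grad: (seq_len, batch, hidden), the gradient w.r.t. the LSTM output
        norms = grad.norm(dim=(0, 2))                   # one norm per utterance
        mask = (norms <= threshold * norms.median()).to(grad.dtype)
        return grad * mask.view(1, -1, 1)

    return filter_grad

# Usage: attach the hook to the LSTM output during training.
lstm = torch.nn.LSTM(input_size=80, hidden_size=1024)
x = torch.randn(100, 8, 80, requires_grad=True)
out, _ = lstm(x)
out.register_hook(make_grad_filter(25.0))
out.sum().backward()
```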
##### training on full librispeech
This model contains 12 encoder layers (LSTM module + Feedforward module). The number of model parameters is 84689496.
The WERs are:
| | test-clean | test-other | comment | decoding mode |
|-------------------------------------|------------|------------|----------------------|----------------------|
| greedy search (max sym per frame 1) | 3.66 | 9.51 | --epoch 40 --avg 15 | simulated streaming |
| greedy search (max sym per frame 1) | 3.66 | 9.48 | --epoch 40 --avg 15 | streaming |
| fast beam search | 3.55 | 9.33 | --epoch 40 --avg 15 | simulated streaming |
| fast beam search | 3.57 | 9.25 | --epoch 40 --avg 15 | streaming |
| modified beam search | 3.55 | 9.28 | --epoch 40 --avg 15 | simulated streaming |
| modified beam search | 3.54 | 9.25 | --epoch 40 --avg 15 | streaming |
Note: `simulated streaming` indicates feeding the full utterance during decoding, while `streaming` indicates feeding a certain number of frames at a time.
The training command is:
```bash
./lstm_transducer_stateless3/train.py \
--world-size 4 \
--num-epochs 40 \
--start-epoch 1 \
--exp-dir lstm_transducer_stateless3/exp \
--full-libri 1 \
--max-duration 500 \
--master-port 12325 \
--num-encoder-layers 12 \
--grad-norm-threshold 25.0 \
--rnn-hidden-size 1024
```
The tensorboard log can be found at
<https://tensorboard.dev/experiment/caNPyr5lT8qAl9qKsXEeEQ/>
The simulated streaming decoding command using greedy search, fast beam search, and modified beam search is:
```bash
for decoding_method in greedy_search fast_beam_search modified_beam_search; do
./lstm_transducer_stateless3/decode.py \
--epoch 40 \
--avg 15 \
--exp-dir lstm_transducer_stateless3/exp \
--max-duration 600 \
--num-encoder-layers 12 \
--rnn-hidden-size 1024 \
--decoding-method $decoding_method \
--use-averaged-model True \
--beam 4 \
--max-contexts 4 \
--max-states 8 \
--beam-size 4
done
```
The streaming decoding command using greedy search, fast beam search, and modified beam search is:
```bash
for decoding_method in greedy_search fast_beam_search modified_beam_search; do
./lstm_transducer_stateless3/streaming_decode.py \
--epoch 40 \
--avg 15 \
--exp-dir lstm_transducer_stateless3/exp \
--max-duration 600 \
--num-encoder-layers 12 \
--rnn-hidden-size 1024 \
--decoding-method $decoding_method \
--use-averaged-model True \
--beam 4 \
--max-contexts 4 \
--max-states 8 \
--beam-size 4
done
```
Pretrained models, training logs, decoding logs, and decoding results
are available at
<https://huggingface.co/Zengwei/icefall-asr-librispeech-lstm-transducer-stateless3-2022-09-28>
### LibriSpeech BPE training results (Pruned Stateless LSTM RNN-T + multi-dataset)
#### [lstm_transducer_stateless2](./lstm_transducer_stateless2)
See <https://github.com/k2-fsa/icefall/pull/558> for more details.
The WERs are:
| | test-clean | test-other | comment |
|-------------------------------------|------------|------------|-------------------------|
| greedy search (max sym per frame 1) | 2.78 | 7.36 | --iter 468000 --avg 16 |
| modified_beam_search | 2.73 | 7.15 | --iter 468000 --avg 16 |
| fast_beam_search | 2.76 | 7.31 | --iter 468000 --avg 16 |
| greedy search (max sym per frame 1) | 2.77 | 7.35 | --iter 472000 --avg 18 |
| modified_beam_search | 2.75 | 7.08 | --iter 472000 --avg 18 |
| fast_beam_search | 2.77 | 7.29 | --iter 472000 --avg 18 |
The training command is:
```bash
#!/usr/bin/env bash
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
./lstm_transducer_stateless2/train.py \
--world-size 8 \
--num-epochs 35 \
--start-epoch 1 \
--full-libri 1 \
--exp-dir lstm_transducer_stateless2/exp \
--max-duration 500 \
--use-fp16 0 \
--lr-epochs 10 \
--num-workers 2 \
--giga-prob 0.9
```
**Note**: Training was killed manually after getting `epoch-18.pt`. Also, we resumed
training after getting `epoch-9.pt`.
The tensorboard log can be found at
<https://tensorboard.dev/experiment/1ziQ2LFmQY2mt4dlUr5dyA/>
The decoding command is
```bash
for m in greedy_search fast_beam_search modified_beam_search; do
for iter in 472000; do
for avg in 8 10 12 14 16 18; do
./lstm_transducer_stateless2/decode.py \
--iter $iter \
--avg $avg \
--exp-dir lstm_transducer_stateless2/exp \
--max-duration 600 \
--num-encoder-layers 12 \
--rnn-hidden-size 1024 \
--decoding-method $m \
--use-averaged-model True \
--beam 4 \
--max-contexts 4 \
--max-states 8 \
--beam-size 4
done
done
done
```
Pretrained models, training logs, decoding logs, and decoding results
are available at
<https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03>
### LibriSpeech BPE training results (Pruned Stateless LSTM RNN-T)
#### [lstm_transducer_stateless](./lstm_transducer_stateless)
It implements an LSTM model with the mechanisms of the reworked model for streaming ASR.
See <https://github.com/k2-fsa/icefall/pull/479> for more details.
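As a rough illustration of the layer structure (a sketch only, not the recipe's code: the actual `RNNEncoderLayer` in `lstm_transducer_stateless/lstm.py` uses scaled modules, layer dropout, and a warmup mechanism), each encoder layer is an LSTM followed by a feedforward module, each wrapped in a residual connection:
```python
# Minimal sketch of one encoder layer using plain PyTorch modules.
# The feedforward dimension of 2048 is an assumption for illustration.
import torch
import torch.nn as nn


class SketchEncoderLayer(nn.Module):
    def __init__(self, d_model: int = 512, rnn_hidden_size: int = 1024, dropout: float = 0.1):
        super().__init__()
        # proj_size maps the LSTM output back to d_model so that the
        # residual connection is dimensionally consistent.
        self.lstm = nn.LSTM(d_model, rnn_hidden_size, proj_size=d_model)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, 2048),
            nn.ReLU(),
            nn.Linear(2048, d_model),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, src: torch.Tensor, states):
        # src: (T, N, d_model); states: (h0, c0) carried over between chunks
        src_lstm, new_states = self.lstm(src, states)
        src = self.dropout(src_lstm) + src  # LSTM module + residual
        src = src + self.dropout(self.feed_forward(src))  # feedforward + residual
        return src, new_states
```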
#### training on full librispeech
##### training on full librispeech
This model contains 12 encoder layers (LSTM module + Feedforward module). The number of model parameters is 84689496.
@ -95,7 +255,7 @@ It is modified from [torchaudio](https://github.com/pytorch/audio).
See <https://github.com/k2-fsa/icefall/pull/440> for more details.
#### With lower latency setup, training on full librispeech
##### With lower latency setup, training on full librispeech
In this model, the lengths of chunk and right context are 32 frames (i.e., 0.32s) and 8 frames (i.e., 0.08s), respectively.
@ -246,7 +406,7 @@ Pretrained models, training logs, decoding logs, and decoding results
are available at
<https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05>
#### With higher latency setup, training on full librispeech
##### With higher latency setup, training on full librispeech
In this model, the lengths of chunk and right context are 64 frames (i.e., 0.64s) and 16 frames (i.e., 0.16s), respectively.
@ -781,14 +941,14 @@ Pre-trained models, training and decoding logs, and decoding results are availab
### LibriSpeech BPE training results (Pruned Stateless Conv-Emformer RNN-T)
[conv_emformer_transducer_stateless](./conv_emformer_transducer_stateless)
#### [conv_emformer_transducer_stateless](./conv_emformer_transducer_stateless)
It implements [Emformer](https://arxiv.org/abs/2010.10759) augmented with convolution module for streaming ASR.
It is modified from [torchaudio](https://github.com/pytorch/audio).
See <https://github.com/k2-fsa/icefall/pull/389> for more details.
#### Training on full librispeech
##### Training on full librispeech
In this model, the lengths of chunk and right context are 32 frames (i.e., 0.32s) and 8 frames (i.e., 0.08s), respectively.
@ -941,7 +1101,7 @@ are available at
### LibriSpeech BPE training results (Pruned Stateless Emformer RNN-T)
[pruned_stateless_emformer_rnnt2](./pruned_stateless_emformer_rnnt2)
#### [pruned_stateless_emformer_rnnt2](./pruned_stateless_emformer_rnnt2)
Use <https://github.com/k2-fsa/icefall/pull/390>.
@ -1009,7 +1169,7 @@ results at:
### LibriSpeech BPE training results (Pruned Stateless Transducer 5)
[pruned_transducer_stateless5](./pruned_transducer_stateless5)
#### [pruned_transducer_stateless5](./pruned_transducer_stateless5)
Same as `Pruned Stateless Transducer 2` but with more layers.
@ -1022,7 +1182,7 @@ The notations `large` and `medium` below are from the [Conformer](https://arxiv.
paper, where the large model has about 118 M parameters and the medium model
has 30.8 M parameters.
#### Large
##### Large
Number of model parameters 118129516 (i.e, 118.13 M).
@ -1082,7 +1242,7 @@ results at:
<https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless5-2022-07-07>
#### Medium
##### Medium
Number of model parameters 30896748 (i.e, 30.9 M).
@ -1142,7 +1302,7 @@ results at:
<https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless5-M-2022-07-07>
#### Baseline-2
##### Baseline-2
It has 88.98 M parameters. Compared to the model in pruned_transducer_stateless2, it has more
layers (24 vs. 12) but is narrower (1536 feedforward dim and 384 encoder dim vs. 2048 feedforward dim and 512 encoder dim).
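A sketch of the encoder options implied by that description (flag names follow `pruned_transducer_stateless5/train.py`; treat the exact invocation as an assumption rather than the command actually used):
```bash
# Hypothetical flags matching the Baseline-2 description; combine with the
# usual --world-size / --exp-dir / --max-duration options shown elsewhere.
./pruned_transducer_stateless5/train.py \
  --num-encoder-layers 24 \
  --dim-feedforward 1536 \
  --encoder-dim 384
```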
@ -1203,13 +1363,13 @@ results at:
### LibriSpeech BPE training results (Pruned Stateless Transducer 4)
[pruned_transducer_stateless4](./pruned_transducer_stateless4)
#### [pruned_transducer_stateless4](./pruned_transducer_stateless4)
This version saves the averaged model during training and decodes with the averaged model.
See <https://github.com/k2-fsa/icefall/issues/337> for details about the idea of model averaging.
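Conceptually, the idea boils down to averaging the parameters of several checkpoints. A minimal sketch of that idea (the real implementation in `icefall.checkpoint` maintains a running averaged model during training and reconstructs the average from just two stored models, so this is only an illustration, not the actual code):
```python
# Sketch: average the "model" state_dicts of several epoch-*.pt checkpoints.
import torch


def average_state_dicts(filenames, device="cpu"):
    n = len(filenames)
    avg = torch.load(filenames[0], map_location=device)["model"]
    for name in filenames[1:]:
        state = torch.load(name, map_location=device)["model"]
        for k in avg:
            avg[k] = avg[k] + state[k]
    for k in avg:
        if avg[k].is_floating_point():
            avg[k] = avg[k] / n
        else:
            # integer buffers (e.g., counters) are combined with floor division
            avg[k] = avg[k] // n
    return avg
```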
#### Training on full librispeech
##### Training on full librispeech
See <https://github.com/k2-fsa/icefall/pull/344>
@ -1285,7 +1445,7 @@ Pretrained models, training logs, decoding logs, and decoding results
are available at
<https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless4-2022-06-03>
#### Training on train-clean-100
##### Training on train-clean-100
See <https://github.com/k2-fsa/icefall/pull/344>
@ -1322,7 +1482,7 @@ The tensorboard log can be found at
### LibriSpeech BPE training results (Pruned Stateless Transducer 3, 2022-04-29)
[pruned_transducer_stateless3](./pruned_transducer_stateless3)
#### [pruned_transducer_stateless3](./pruned_transducer_stateless3)
Same as `Pruned Stateless Transducer 2` but using the XL subset from
[GigaSpeech](https://github.com/SpeechColab/GigaSpeech) as extra training data.
@ -1536,10 +1696,10 @@ can be found at
### LibriSpeech BPE training results (Pruned Transducer 2)
[pruned_transducer_stateless2](./pruned_transducer_stateless2)
#### [pruned_transducer_stateless2](./pruned_transducer_stateless2)
This uses a reworked version of the conformer encoder, with many changes.
#### Training on fulll librispeech
##### Training on full librispeech
Using commit `34aad74a2c849542dd5f6359c9e6b527e8782fd6`.
See <https://github.com/k2-fsa/icefall/pull/288>
@ -1588,7 +1748,7 @@ can be found at
<https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless2-2022-04-29>
#### Training on train-clean-100:
##### Training on train-clean-100:
Trained with 1 job:
```

View File

@ -253,7 +253,9 @@ class ConformerEncoderLayer(nn.Module):
residual = src
if self.normalize_before:
src = self.norm_conv(src)
src = residual + self.dropout(self.conv_module(src))
src = residual + self.dropout(
self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
)
if not self.normalize_before:
src = self.norm_conv(src)
@ -890,11 +892,16 @@ class ConvolutionModule(nn.Module):
)
self.activation = Swish()
def forward(self, x: Tensor) -> Tensor:
def forward(
self,
x: Tensor,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
"""Compute convolution module.
Args:
x: Input tensor (#time, batch, channels).
src_key_padding_mask: the mask for the src keys per batch (optional).
Returns:
Tensor: Output tensor (#time, batch, channels).
@ -908,6 +915,8 @@ class ConvolutionModule(nn.Module):
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
# 1D Depthwise Conv
if src_key_padding_mask is not None:
x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
x = self.depthwise_conv(x)
if self.use_batchnorm:
x = self.norm(x)

View File

@ -322,7 +322,7 @@ def main():
H = k2.ctc_topo(
max_token=max_token_id,
modified=False,
modified=params.num_classes > 500,
device=device,
)

View File

@ -268,7 +268,9 @@ class ConformerEncoderLayer(nn.Module):
src = src + self.dropout(src_att)
# convolution module
src = src + self.dropout(self.conv_module(src))
src = src + self.dropout(
self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
)
# feed forward module
src = src + self.dropout(self.feed_forward(src))
@ -921,11 +923,16 @@ class ConvolutionModule(nn.Module):
initial_scale=0.25,
)
def forward(self, x: Tensor) -> Tensor:
def forward(
self,
x: Tensor,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
"""Compute convolution module.
Args:
x: Input tensor (#time, batch, channels).
src_key_padding_mask: the mask for the src keys per batch (optional).
Returns:
Tensor: Output tensor (#time, batch, channels).
@ -941,6 +948,8 @@ class ConvolutionModule(nn.Module):
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
# 1D Depthwise Conv
if src_key_padding_mask is not None:
x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
x = self.depthwise_conv(x)
x = self.deriv_balancer2(x)

View File

@ -247,7 +247,9 @@ class ConformerEncoderLayer(nn.Module):
residual = src
if self.normalize_before:
src = self.norm_conv(src)
src = residual + self.dropout(self.conv_module(src))
src = residual + self.dropout(
self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
)
if not self.normalize_before:
src = self.norm_conv(src)
@ -878,11 +880,16 @@ class ConvolutionModule(nn.Module):
)
self.activation = Swish()
def forward(self, x: Tensor) -> Tensor:
def forward(
self,
x: Tensor,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
"""Compute convolution module.
Args:
x: Input tensor (#time, batch, channels).
src_key_padding_mask: the mask for the src keys per batch (optional).
Returns:
Tensor: Output tensor (#time, batch, channels).
@ -896,6 +903,8 @@ class ConvolutionModule(nn.Module):
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
# 1D Depthwise Conv
if src_key_padding_mask is not None:
x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
x = self.depthwise_conv(x)
x = self.activation(self.norm(x))

View File

@ -66,6 +66,13 @@ def compute_fbank_librispeech():
)
assert manifests is not None
assert len(manifests) == len(dataset_parts), (
len(manifests),
len(dataset_parts),
list(manifests.keys()),
dataset_parts,
)
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once.

View File

@ -65,6 +65,8 @@ def compute_fbank_musan():
assert len(manifests) == len(dataset_parts), (
len(manifests),
len(dataset_parts),
list(manifests.keys()),
dataset_parts,
)
musan_cuts_path = output_dir / "musan_cuts.jsonl.gz"

View File

@ -68,6 +68,13 @@ def preprocess_giga_speech():
)
assert manifests is not None
assert len(manifests) == len(dataset_parts), (
len(manifests),
len(dataset_parts),
list(manifests.keys()),
dataset_parts,
)
for partition, m in manifests.items():
logging.info(f"Processing {partition}")
raw_cuts_path = output_dir / f"{prefix}_cuts_{partition}_raw.{suffix}"

View File

@ -116,6 +116,8 @@ class RNN(EncoderInterface):
Period of auxiliary layers used for random combiner during training.
If set to 0, will not use the random combiner (Default).
You can set a positive integer to use the random combiner, e.g., 3.
is_pnnx:
True to make this class exportable via PNNX.
"""
def __init__(
@ -129,6 +131,7 @@ class RNN(EncoderInterface):
dropout: float = 0.1,
layer_dropout: float = 0.075,
aux_layer_period: int = 0,
is_pnnx: bool = False,
) -> None:
super(RNN, self).__init__()
@ -142,7 +145,13 @@ class RNN(EncoderInterface):
# That is, it does two things simultaneously:
# (1) subsampling: T -> T//subsampling_factor
# (2) embedding: num_features -> d_model
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
self.encoder_embed = Conv2dSubsampling(
num_features,
d_model,
is_pnnx=is_pnnx,
)
self.is_pnnx = is_pnnx
self.num_encoder_layers = num_encoder_layers
self.d_model = d_model
@ -209,7 +218,13 @@ class RNN(EncoderInterface):
# lengths = ((x_lens - 3) // 2 - 1) // 2 # issues a warning
#
# Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
lengths = (((x_lens - 3) >> 1) - 1) >> 1
if not self.is_pnnx:
lengths = (((x_lens - 3) >> 1) - 1) >> 1
else:
lengths1 = torch.floor((x_lens - 3) / 2)
lengths = torch.floor((lengths1 - 1) / 2)
lengths = lengths.to(x_lens)
if not torch.jit.is_tracing():
assert x.size(0) == lengths.max().item()
@ -359,7 +374,7 @@ class RNNEncoderLayer(nn.Module):
# for cell state
assert states[1].shape == (1, src.size(1), self.rnn_hidden_size)
src_lstm, new_states = self.lstm(src, states)
src = src + self.dropout(src_lstm)
src = self.dropout(src_lstm) + src
# feed forward module
src = src + self.dropout(self.feed_forward(src))
@ -505,6 +520,7 @@ class Conv2dSubsampling(nn.Module):
layer1_channels: int = 8,
layer2_channels: int = 32,
layer3_channels: int = 128,
is_pnnx: bool = False,
) -> None:
"""
Args:
@ -517,6 +533,9 @@ class Conv2dSubsampling(nn.Module):
Number of channels in layer1
layer2_channels:
Number of channels in layer2
is_pnnx:
True if we are converting the model to PNNX format.
False otherwise.
"""
assert in_channels >= 9
super().__init__()
@ -559,6 +578,10 @@ class Conv2dSubsampling(nn.Module):
channel_dim=-1, min_positive=0.45, max_positive=0.55
)
# ncnn supports only batch size == 1
self.is_pnnx = is_pnnx
self.conv_out_dim = self.out.weight.shape[1]
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x.
@ -572,9 +595,15 @@ class Conv2dSubsampling(nn.Module):
# On entry, x is (N, T, idim)
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
x = self.conv(x)
# Now x is of shape (N, odim, ((T-3)//2-1)//2, ((idim-3)//2-1)//2)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
if torch.jit.is_tracing() and self.is_pnnx:
x = x.permute(0, 2, 1, 3).reshape(1, -1, self.conv_out_dim)
x = self.out(x)
else:
# Now x is of shape (N, odim, ((T-3)//2-1)//2, ((idim-3)//2-1)//2)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
# Now x is of shape (N, ((T-3)//2-1)//2, odim)
x = self.out_norm(x)
x = self.out_balancer(x)
@ -773,7 +802,7 @@ class RandomCombine(nn.Module):
"""
logprobs = (
torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device)
* self.stddev
* self.stddev # noqa
)
logprobs[:, -1] += self.final_log_weight
return logprobs.softmax(dim=1)

View File

@ -0,0 +1 @@
../pruned_transducer_stateless3/asr_datamodule.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/beam_search.py

View File

@ -0,0 +1,827 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./lstm_transducer_stateless2/decode.py \
--epoch 35 \
--avg 15 \
--exp-dir ./lstm_transducer_stateless2/exp \
--max-duration 600 \
--decoding-method greedy_search
(2) beam search (not recommended)
./lstm_transducer_stateless2/decode.py \
--epoch 35 \
--avg 15 \
--exp-dir ./lstm_transducer_stateless2/exp \
--max-duration 600 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./lstm_transducer_stateless2/decode.py \
--epoch 35 \
--avg 15 \
--exp-dir ./lstm_transducer_stateless2/exp \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search (one best)
./lstm_transducer_stateless2/decode.py \
--epoch 35 \
--avg 15 \
--exp-dir ./lstm_transducer_stateless2/exp \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(5) fast beam search (nbest)
./lstm_transducer_stateless2/decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(6) fast beam search (nbest oracle WER)
./lstm_transducer_stateless2/decode.py \
--epoch 35 \
--avg 15 \
--exp-dir ./lstm_transducer_stateless2/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(7) fast beam search (with LG)
./lstm_transducer_stateless2/decode.py \
--epoch 35 \
--avg 15 \
--exp-dir ./lstm_transducer_stateless2/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_LG \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
"""
import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import AsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from librispeech import LibriSpeech
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="lstm_transducer_stateless2/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_bpe_500",
help="The lang dir containing word table and LG graph",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
If you use fast_beam_search_nbest_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=20.0,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search,
fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle
""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_nbest_LG.
It specifies the scale for n-gram LM scores.
""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--max-states",
type=int,
default=64,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--num-paths",
type=int,
default=200,
help="""Number of paths for nbest decoding.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
add_model_arguments(parser)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or LG. Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
# tail padding here to alleviate the tail deletion problem
num_tail_padded_frames = 35
feature = torch.nn.functional.pad(
feature,
(0, 0, 0, num_tail_padded_frames),
mode="constant",
value=LOG_EPS,
)
feature_lens += num_tail_padded_frames
encoder_out, encoder_out_lens, _ = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest":
hyp_tokens = fast_beam_search_nbest(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_oracle":
hyp_tokens = fast_beam_search_nbest_oracle(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
ref_texts=sp.encode(supervisions["text"]),
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append(sp.decode(hyp).split())
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif "fast_beam_search" in params.decoding_method:
key = f"beam_{params.beam}_"
key += f"max_contexts_{params.max_contexts}_"
key += f"max_states_{params.max_states}"
if "nbest" in params.decoding_method:
key += f"_num_paths_{params.num_paths}_"
key += f"nbest_scale_{params.nbest_scale}"
if "LG" in params.decoding_method:
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {key: hyps}
else:
return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or LG. Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 50
else:
log_interval = 20
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
word_table=word_table,
batch=batch,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
AsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
if "nbest" in params.decoding_method:
params.suffix += f"-nbest-scale-{params.nbest_scale}"
params.suffix += f"-num-paths-{params.num_paths}"
if "LG" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params, enable_giga=False)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(
average_checkpoints(filenames, device=device),
strict=False,
)
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
),
strict=False,
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
),
strict=False,
)
model.to(device)
model.eval()
if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG":
lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
word_table = None
decoding_graph = k2.trivial_graph(
params.vocab_size - 1, device=device
)
else:
decoding_graph = None
word_table = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
asr_datamodule = AsrDataModule(args)
librispeech = LibriSpeech(manifest_dir=args.manifest_dir)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_clean_dl = asr_datamodule.test_dataloaders(test_clean_cuts)
test_other_dl = asr_datamodule.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"]
test_dl = [test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/decoder.py

View File

@ -0,0 +1 @@
../transducer_stateless/encoder_interface.py

View File

@ -0,0 +1,427 @@
#!/usr/bin/env python3
# flake8: noqa
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
(1) Export to torchscript model using torch.jit.trace()
./lstm_transducer_stateless2/export.py \
--exp-dir ./lstm_transducer_stateless2/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 35 \
--avg 10 \
--jit-trace 1
It will generate 3 files: `encoder_jit_trace.pt`,
`decoder_jit_trace.pt`, and `joiner_jit_trace.pt`.
(2) Export `model.state_dict()`
./lstm_transducer_stateless2/export.py \
--exp-dir ./lstm_transducer_stateless2/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 35 \
--avg 10
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
load it by `icefall.checkpoint.load_checkpoint()`.
To use the generated file with `lstm_transducer_stateless2/decode.py`,
you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
./lstm_transducer_stateless2/decode.py \
--exp-dir ./lstm_transducer_stateless2/exp \
--epoch 9999 \
--avg 1 \
--max-duration 600 \
--decoding-method greedy_search \
--bpe-model data/lang_bpe_500/bpe.model
Check ./pretrained.py for its usage.
Note: If you don't want to train a model from scratch, we have
provided one for you. You can get it at
https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
with the following commands:
sudo apt-get install git-lfs
git lfs install
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
# You will find the pre-trained models in icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp
"""
import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import torch
import torch.nn as nn
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless3/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--jit-trace",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.trace.
It will generate 3 files:
- encoder_jit_trace.pt
- decoder_jit_trace.pt
- joiner_jit_trace.pt
Check ./jit_pretrained.py for how to use them.
""",
)
parser.add_argument(
"--pnnx",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.trace for later
converting to PNNX. It will generate 3 files:
- encoder_jit_trace-pnnx.pt
- decoder_jit_trace-pnnx.pt
- joiner_jit_trace-pnnx.pt
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
add_model_arguments(parser)
return parser
def export_encoder_model_jit_trace(
encoder_model: nn.Module,
encoder_filename: str,
) -> None:
"""Export the given encoder model with torch.jit.trace()
Note: The warmup argument is fixed to 1.
Args:
encoder_model:
The input encoder model
encoder_filename:
The filename to save the exported model.
"""
x = torch.zeros(1, 100, 80, dtype=torch.float32)
x_lens = torch.tensor([100], dtype=torch.int64)
states = encoder_model.get_init_states()
traced_model = torch.jit.trace(encoder_model, (x, x_lens, states))
traced_model.save(encoder_filename)
logging.info(f"Saved to {encoder_filename}")
def export_decoder_model_jit_trace(
decoder_model: nn.Module,
decoder_filename: str,
) -> None:
"""Export the given decoder model with torch.jit.trace()
Note: The argument need_pad is fixed to False.
Args:
decoder_model:
The input decoder model
decoder_filename:
The filename to save the exported model.
"""
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
need_pad = torch.tensor([False])
traced_model = torch.jit.trace(decoder_model, (y, need_pad))
traced_model.save(decoder_filename)
logging.info(f"Saved to {decoder_filename}")
def export_joiner_model_jit_trace(
joiner_model: nn.Module,
joiner_filename: str,
) -> None:
"""Export the given joiner model with torch.jit.trace()
Note: The argument project_input is fixed to True. A user should not
project the encoder_out/decoder_out by himself/herself. The exported joiner
will do that for the user.
Args:
joiner_model:
The input joiner model
joiner_filename:
The filename to save the exported model.
"""
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out))
traced_model.save(joiner_filename)
logging.info(f"Saved to {joiner_filename}")
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
if params.pnnx:
params.is_pnnx = params.pnnx
logging.info("For PNNX")
logging.info("About to create model")
model = get_transducer_model(params, enable_giga=False)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(
average_checkpoints(filenames, device=device),
strict=False,
)
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(
average_checkpoints(filenames, device=device),
strict=False,
)
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
),
strict=False,
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
),
strict=False,
)
model.to("cpu")
model.eval()
if params.pnnx:
convert_scaled_to_non_scaled(model, inplace=True)
logging.info("Using torch.jit.trace()")
encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt"
export_encoder_model_jit_trace(model.encoder, encoder_filename)
decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt"
export_decoder_model_jit_trace(model.decoder, decoder_filename)
joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt"
export_joiner_model_jit_trace(model.joiner, joiner_filename)
elif params.jit_trace is True:
convert_scaled_to_non_scaled(model, inplace=True)
logging.info("Using torch.jit.trace()")
encoder_filename = params.exp_dir / "encoder_jit_trace.pt"
export_encoder_model_jit_trace(model.encoder, encoder_filename)
decoder_filename = params.exp_dir / "decoder_jit_trace.pt"
export_decoder_model_jit_trace(model.decoder, decoder_filename)
joiner_filename = params.exp_dir / "joiner_jit_trace.pt"
export_joiner_model_jit_trace(model.joiner, joiner_filename)
else:
logging.info("Not using torchscript")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../pruned_transducer_stateless3/gigaspeech.py

View File

@ -0,0 +1,323 @@
#!/usr/bin/env python3
# flake8: noqa
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads torchscript models, either exported by `torch.jit.trace()`
or by `torch.jit.script()`, and uses them to decode waves.
You can use the following command to get the exported models:
./lstm_transducer_stateless2/export.py \
--exp-dir ./lstm_transducer_stateless2/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \
--avg 10 \
--jit-trace 1
Usage of this script:
./lstm_transducer_stateless2/jit_pretrained.py \
--encoder-model-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace.pt \
--decoder-model-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace.pt \
--joiner-model-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
/path/to/foo.wav \
/path/to/bar.wav
"""
import argparse
import logging
import math
from typing import List
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--encoder-model-filename",
type=str,
required=True,
help="Path to the encoder torchscript model. ",
)
parser.add_argument(
"--decoder-model-filename",
type=str,
required=True,
help="Path to the decoder torchscript model. ",
)
parser.add_argument(
"--joiner-model-filename",
type=str,
required=True,
help="Path to the joiner torchscript model. ",
)
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to bpe.model.""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="Context size of the decoder model",
)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
def greedy_search(
decoder: torch.jit.ScriptModule,
joiner: torch.jit.ScriptModule,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
context_size: int,
) -> List[List[int]]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
decoder:
The decoder model.
joiner:
The joiner model.
encoder_out:
A 3-D tensor of shape (N, T, C)
encoder_out_lens:
A 1-D tensor of shape (N,).
context_size:
The context size of the decoder model.
Returns:
Return the decoded results for each utterance.
"""
assert encoder_out.ndim == 3
assert encoder_out.size(0) >= 1, encoder_out.size(0)
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
device = encoder_out.device
blank_id = 0 # hard-code to 0
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
hyps = [[blank_id] * context_size for _ in range(N)]
decoder_input = torch.tensor(
hyps,
device=device,
dtype=torch.int64,
) # (N, context_size)
decoder_out = decoder(
decoder_input,
need_pad=torch.tensor([False]),
).squeeze(1)
offset = 0
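# packed_encoder_out stores frames time-major: batch_size_list[t] is the number
# of utterances still active at time step t, so each iteration below processes
# one frame for the first batch_size hypotheses.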
for batch_size in batch_size_list:
start = offset
end = offset + batch_size
current_encoder_out = packed_encoder_out.data[start:end]
current_encoder_out = current_encoder_out
# current_encoder_out's shape: (batch_size, encoder_out_dim)
offset = end
decoder_out = decoder_out[:batch_size]
logits = joiner(
current_encoder_out,
decoder_out,
)
# logits' shape: (batch_size, vocab_size)
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
hyps[i].append(v)
emitted = True
if emitted:
# update decoder output
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
decoder_input = torch.tensor(
decoder_input,
device=device,
dtype=torch.int64,
)
decoder_out = decoder(
decoder_input,
need_pad=torch.tensor([False]),
)
decoder_out = decoder_out.squeeze(1)
sorted_ans = [h[context_size:] for h in hyps]
ans = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
encoder = torch.jit.load(args.encoder_model_filename)
decoder = torch.jit.load(args.decoder_model_filename)
joiner = torch.jit.load(args.joiner_model_filename)
encoder.eval()
decoder.eval()
joiner.eval()
encoder.to(device)
decoder.to(device)
joiner.to(device)
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model)
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = args.sample_rate
opts.mel_opts.num_bins = 80
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {args.sound_files}")
waves = read_sound_files(
filenames=args.sound_files,
expected_sample_rate=args.sample_rate,
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features,
batch_first=True,
padding_value=math.log(1e-10),
)
feature_lengths = torch.tensor(feature_lengths, device=device)
states = encoder.get_init_states(batch_size=features.size(0), device=device)
encoder_out, encoder_out_lens, _ = encoder(
x=features,
x_lens=feature_lengths,
states=states,
)
hyps = greedy_search(
decoder=decoder,
joiner=joiner,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
context_size=args.context_size,
)
s = "\n"
for filename, hyp in zip(args.sound_files, hyps):
words = sp.decode(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/joiner.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless3/librispeech.py

View File

@ -0,0 +1 @@
../lstm_transducer_stateless/lstm.py

View File

@ -0,0 +1,241 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import k2
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from scaling import ScaledLinear
from icefall.utils import add_sos
class Transducer(nn.Module):
"""It implements https://arxiv.org/pdf/1211.3711.pdf
"Sequence Transduction with Recurrent Neural Networks"
"""
def __init__(
self,
encoder: EncoderInterface,
decoder: nn.Module,
joiner: nn.Module,
encoder_dim: int,
decoder_dim: int,
joiner_dim: int,
vocab_size: int,
decoder_giga: Optional[nn.Module] = None,
joiner_giga: Optional[nn.Module] = None,
):
"""
Args:
encoder:
It is the transcription network in the paper. It accepts
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
It returns two tensors: `logits` of shape (N, T, encoder_dim) and
`logit_lens` of shape (N,).
decoder:
It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, decoder_dim).
It should contain one attribute: `blank_id`.
joiner:
It has two inputs with shapes: (N, T, encoder_dim) and
(N, U, decoder_dim).
Its output shape is (N, T, U, vocab_size). Note that its output
contains unnormalized probs, i.e., not processed by log-softmax.
encoder_dim:
Output dimension of the encoder network.
decoder_dim:
Output dimension of the decoder network.
joiner_dim:
Input dimension of the joiner network.
vocab_size:
Output dimension of the joiner network.
decoder_giga:
Optional. The decoder network for the GigaSpeech dataset.
joiner_giga:
Optional. The joiner network for the GigaSpeech dataset.
"""
super().__init__()
assert isinstance(encoder, EncoderInterface), type(encoder)
assert hasattr(decoder, "blank_id")
self.encoder = encoder
self.decoder = decoder
self.joiner = joiner
self.decoder_giga = decoder_giga
self.joiner_giga = joiner_giga
self.simple_am_proj = ScaledLinear(
encoder_dim, vocab_size, initial_speed=0.5
)
self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
if decoder_giga is not None:
self.simple_am_proj_giga = ScaledLinear(
encoder_dim, vocab_size, initial_speed=0.5
)
self.simple_lm_proj_giga = ScaledLinear(decoder_dim, vocab_size)
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
y: k2.RaggedTensor,
libri: bool = True,
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
warmup: float = 1.0,
reduction: str = "sum",
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
libri:
True to use the decoder and joiner for the LibriSpeech dataset.
False to use the decoder and joiner for the GigaSpeech dataset.
prune_range:
The prune range for the RNN-T loss, i.e., how many symbols (context)
we consider for each frame when computing the loss.
am_scale:
The scale used to smooth the loss with the am (output of the encoder
network) part.
lm_scale:
The scale used to smooth the loss with the lm (output of the prediction
network) part.
warmup:
A value warmup >= 0 that determines which modules are active. Values of
warmup > 1 mean "fully warmed up", i.e., all modules are active.
reduction:
"sum" to sum the losses over all utterances in the batch.
"none" to return the loss in a 1-D tensor for each utterance
in the batch.
Returns:
Return the transducer loss.
Note:
Regarding am_scale & lm_scale, they make the loss function take
the form:
lm_scale * lm_probs + am_scale * am_probs +
(1-lm_scale-am_scale) * combined_probs
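For example, with am_scale=0.1 and lm_scale=0.25, the combined term is
weighted by 1 - 0.25 - 0.1 = 0.65.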
"""
assert reduction in ("sum", "none"), reduction
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
assert x.size(0) == x_lens.size(0) == y.dim0
encoder_out, x_lens, _ = self.encoder(x, x_lens, warmup=warmup)
assert torch.all(x_lens > 0)
if libri:
decoder = self.decoder
simple_lm_proj = self.simple_lm_proj
simple_am_proj = self.simple_am_proj
joiner = self.joiner
else:
decoder = self.decoder_giga
simple_lm_proj = self.simple_lm_proj_giga
simple_am_proj = self.simple_am_proj_giga
joiner = self.joiner_giga
# Now for the decoder, i.e., the prediction network
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
blank_id = decoder.blank_id
sos_y = add_sos(y, sos_id=blank_id)
# sos_y_padded: [B, S + 1], start with SOS.
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
# decoder_out: [B, S + 1, decoder_dim]
decoder_out = decoder(sos_y_padded)
# Note: y does not start with SOS
# y_padded : [B, S]
y_padded = y.pad(mode="constant", padding_value=0)
y_padded = y_padded.to(torch.int64)
boundary = torch.zeros(
(x.size(0), 4), dtype=torch.int64, device=x.device
)
boundary[:, 2] = y_lens
boundary[:, 3] = x_lens
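# Per k2's convention, each row of `boundary` is
# [begin_symbol, begin_frame, end_symbol, end_frame], i.e., [0, 0, S, T]
# for the corresponding utterance.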
lm = simple_lm_proj(decoder_out)
am = simple_am_proj(encoder_out)
with torch.cuda.amp.autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
symbols=y_padded,
termination_symbol=blank_id,
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
reduction=reduction,
return_grad=True,
)
# ranges : [B, T, prune_range]
ranges = k2.get_rnnt_prune_ranges(
px_grad=px_grad,
py_grad=py_grad,
boundary=boundary,
s_range=prune_range,
)
# am_pruned : [B, T, prune_range, encoder_dim]
# lm_pruned : [B, T, prune_range, decoder_dim]
am_pruned, lm_pruned = k2.do_rnnt_pruning(
am=joiner.encoder_proj(encoder_out),
lm=joiner.decoder_proj(decoder_out),
ranges=ranges,
)
# logits : [B, T, prune_range, vocab_size]
# project_input=False since we already applied the joiner's input
# projections (encoder_proj and decoder_proj) prior to do_rnnt_pruning
# (this is an optimization for speed).
logits = joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
reduction=reduction,
)
return (simple_loss, pruned_loss)

View File

@ -0,0 +1,295 @@
#!/usr/bin/env python3
# flake8: noqa
#
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
./lstm_transducer_stateless2/ncnn-decode.py \
--bpe-model-filename ./data/lang_bpe_500/bpe.model \
--encoder-param-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.param \
--encoder-bin-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.bin \
--decoder-param-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.param \
--decoder-bin-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.bin \
--joiner-param-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-iter-468000-avg-16-pnnx.ncnn.param \
--joiner-bin-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-iter-468000-avg-16-pnnx.ncnn.bin \
./test_wavs/1089-134686-0001.wav
"""
import argparse
import logging
from typing import List
import kaldifeat
import ncnn
import sentencepiece as spm
import torch
import torchaudio
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--bpe-model-filename",
type=str,
help="Path to bpe.model",
)
parser.add_argument(
"--encoder-param-filename",
type=str,
help="Path to encoder.ncnn.param",
)
parser.add_argument(
"--encoder-bin-filename",
type=str,
help="Path to encoder.ncnn.bin",
)
parser.add_argument(
"--decoder-param-filename",
type=str,
help="Path to decoder.ncnn.param",
)
parser.add_argument(
"--decoder-bin-filename",
type=str,
help="Path to decoder.ncnn.bin",
)
parser.add_argument(
"--joiner-param-filename",
type=str,
help="Path to joiner.ncnn.param",
)
parser.add_argument(
"--joiner-bin-filename",
type=str,
help="Path to joiner.ncnn.bin",
)
parser.add_argument(
"sound_filename",
type=str,
help="Path to foo.wav",
)
return parser.parse_args()
class Model:
def __init__(self, args):
self.init_encoder(args)
self.init_decoder(args)
self.init_joiner(args)
def init_encoder(self, args):
encoder_net = ncnn.Net()
encoder_net.opt.use_packing_layout = False
encoder_net.opt.use_fp16_storage = False
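# Packing layout and fp16 storage are disabled here; the assumption is
# that the custom LSTM layers in the modified ncnn build expect plain
# fp32, unpacked tensor layouts.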
encoder_param = args.encoder_param_filename
encoder_model = args.encoder_bin_filename
encoder_net.load_param(encoder_param)
encoder_net.load_model(encoder_model)
self.encoder_net = encoder_net
def init_decoder(self, args):
decoder_param = args.decoder_param_filename
decoder_model = args.decoder_bin_filename
decoder_net = ncnn.Net()
decoder_net.opt.use_packing_layout = False
decoder_net.load_param(decoder_param)
decoder_net.load_model(decoder_model)
self.decoder_net = decoder_net
def init_joiner(self, args):
joiner_param = args.joiner_param_filename
joiner_model = args.joiner_bin_filename
joiner_net = ncnn.Net()
joiner_net.opt.use_packing_layout = False
joiner_net.load_param(joiner_param)
joiner_net.load_model(joiner_model)
self.joiner_net = joiner_net
def run_encoder(self, x, states):
with self.encoder_net.create_extractor() as ex:
ex.set_num_threads(10)
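# in0/in1 are the features and their lengths; in2/in3 are the initial
# LSTM hidden and cell states (hx, cx).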
ex.input("in0", ncnn.Mat(x.numpy()).clone())
x_lens = torch.tensor([x.size(0)], dtype=torch.float32)
ex.input("in1", ncnn.Mat(x_lens.numpy()).clone())
ex.input("in2", ncnn.Mat(states[0].numpy()).clone())
ex.input("in3", ncnn.Mat(states[1].numpy()).clone())
ret, ncnn_out0 = ex.extract("out0")
assert ret == 0, ret
ret, ncnn_out1 = ex.extract("out1")
assert ret == 0, ret
ret, ncnn_out2 = ex.extract("out2")
assert ret == 0, ret
ret, ncnn_out3 = ex.extract("out3")
assert ret == 0, ret
encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone()
encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to(
torch.int32
)
hx = torch.from_numpy(ncnn_out2.numpy()).clone()
cx = torch.from_numpy(ncnn_out3.numpy()).clone()
return encoder_out, encoder_out_lens, hx, cx
def run_decoder(self, decoder_input):
assert decoder_input.dtype == torch.int32
with self.decoder_net.create_extractor() as ex:
ex.set_num_threads(10)
ex.input("in0", ncnn.Mat(decoder_input.numpy()).clone())
ret, ncnn_out0 = ex.extract("out0")
assert ret == 0, ret
decoder_out = torch.from_numpy(ncnn_out0.numpy()).clone()
return decoder_out
def run_joiner(self, encoder_out, decoder_out):
with self.joiner_net.create_extractor() as ex:
ex.set_num_threads(10)
ex.input("in0", ncnn.Mat(encoder_out.numpy()).clone())
ex.input("in1", ncnn.Mat(decoder_out.numpy()).clone())
ret, ncnn_out0 = ex.extract("out0")
assert ret == 0, ret
joiner_out = torch.from_numpy(ncnn_out0.numpy()).clone()
return joiner_out
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
def greedy_search(model: Model, encoder_out: torch.Tensor):
assert encoder_out.ndim == 2
T = encoder_out.size(0)
context_size = 2
blank_id = 0 # hard-code to 0
hyp = [blank_id] * context_size
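# The stateless decoder conditions only on the last `context_size`
# tokens, so the hypothesis is primed with `context_size` blanks.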
decoder_input = torch.tensor(hyp, dtype=torch.int32) # (1, context_size)
decoder_out = model.run_decoder(decoder_input).squeeze(0)
# print(decoder_out.shape) # (512,)
for t in range(T):
encoder_out_t = encoder_out[t]
joiner_out = model.run_joiner(encoder_out_t, decoder_out)
# print(joiner_out.shape) # [500]
y = joiner_out.argmax(dim=0).tolist()
if y != blank_id:
hyp.append(y)
decoder_input = hyp[-context_size:]
decoder_input = torch.tensor(decoder_input, dtype=torch.int32)
decoder_out = model.run_decoder(decoder_input).squeeze(0)
return hyp[context_size:]
def main():
args = get_args()
logging.info(vars(args))
model = Model(args)
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model_filename)
sound_file = args.sound_filename
sample_rate = 16000
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = "cpu"
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = sample_rate
opts.mel_opts.num_bins = 80
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {sound_file}")
wave_samples = read_sound_files(
filenames=[sound_file],
expected_sample_rate=sample_rate,
)[0]
logging.info("Decoding started")
features = fbank(wave_samples)
num_encoder_layers = 12
d_model = 512
rnn_hidden_size = 1024
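# These values are assumed to match the encoder configuration used when
# exporting the pnnx model (the lstm_transducer_stateless2 defaults:
# 12 encoder layers, d_model=512, rnn_hidden_size=1024).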
states = (
torch.zeros(num_encoder_layers, d_model),
torch.zeros(
num_encoder_layers,
rnn_hidden_size,
),
)
encoder_out, encoder_out_lens, hx, cx = model.run_encoder(features, states)
hyp = greedy_search(model, encoder_out)
logging.info(sound_file)
logging.info(sp.decode(hyp))
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/optim.py

View File

@ -0,0 +1,355 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./lstm_transducer_stateless2/pretrained.py \
--checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
(2) beam search
./lstm_transducer_stateless2/pretrained.py \
--checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(3) modified beam search
./lstm_transducer_stateless2/pretrained.py \
--checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(4) fast beam search
./lstm_transducer_stateless2/pretrained.py \
--checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
You can also use `./lstm_transducer_stateless2/exp/epoch-xx.pt`.
Note: ./lstm_transducer_stateless2/exp/pretrained.pt is generated by
./lstm_transducer_stateless2/export.py
You can find pretrained models by visiting
https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/tree/main/exp
"""
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to bpe.model.""",
)
parser.add_argument(
"--method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame. Used only when
--method is greedy_search.
""",
)
add_model_arguments(parser)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Creating model")
model = get_transducer_model(params, enable_giga=False)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
model.device = device
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
feature_lengths = torch.tensor(feature_lengths, device=device)
encoder_out, encoder_out_lens, _ = model.encoder(
x=features, x_lens=feature_lengths
)
num_waves = encoder_out.size(0)
hyps = []
msg = f"Using {params.method}"
if params.method == "beam_search":
msg += f" with beam size {params.beam_size}"
logging.info(msg)
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
for i in range(num_waves):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split())
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/scaling.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless3/scaling_converter.py

View File

@ -0,0 +1,353 @@
#!/usr/bin/env python3
# flake8: noqa
#
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from typing import List, Optional
import ncnn
import sentencepiece as spm
import torch
import torchaudio
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--bpe-model-filename",
type=str,
help="Path to bpe.model",
)
parser.add_argument(
"--encoder-param-filename",
type=str,
help="Path to encoder.ncnn.param",
)
parser.add_argument(
"--encoder-bin-filename",
type=str,
help="Path to encoder.ncnn.bin",
)
parser.add_argument(
"--decoder-param-filename",
type=str,
help="Path to decoder.ncnn.param",
)
parser.add_argument(
"--decoder-bin-filename",
type=str,
help="Path to decoder.ncnn.bin",
)
parser.add_argument(
"--joiner-param-filename",
type=str,
help="Path to joiner.ncnn.param",
)
parser.add_argument(
"--joiner-bin-filename",
type=str,
help="Path to joiner.ncnn.bin",
)
parser.add_argument(
"sound_filename",
type=str,
help="Path to foo.wav",
)
return parser.parse_args()
class Model:
def __init__(self, args):
self.init_encoder(args)
self.init_decoder(args)
self.init_joiner(args)
def init_encoder(self, args):
encoder_net = ncnn.Net()
encoder_net.opt.use_packing_layout = False
encoder_net.opt.use_fp16_storage = False
encoder_param = args.encoder_param_filename
encoder_model = args.encoder_bin_filename
encoder_net.load_param(encoder_param)
encoder_net.load_model(encoder_model)
self.encoder_net = encoder_net
def init_decoder(self, args):
decoder_param = args.decoder_param_filename
decoder_model = args.decoder_bin_filename
decoder_net = ncnn.Net()
decoder_net.opt.use_packing_layout = False
decoder_net.load_param(decoder_param)
decoder_net.load_model(decoder_model)
self.decoder_net = decoder_net
def init_joiner(self, args):
joiner_param = args.joiner_param_filename
joiner_model = args.joiner_bin_filename
joiner_net = ncnn.Net()
joiner_net.opt.use_packing_layout = False
joiner_net.load_param(joiner_param)
joiner_net.load_model(joiner_model)
self.joiner_net = joiner_net
def run_encoder(self, x, states):
with self.encoder_net.create_extractor() as ex:
# ex.set_num_threads(10)
ex.input("in0", ncnn.Mat(x.numpy()).clone())
x_lens = torch.tensor([x.size(0)], dtype=torch.float32)
ex.input("in1", ncnn.Mat(x_lens.numpy()).clone())
ex.input("in2", ncnn.Mat(states[0].numpy()).clone())
ex.input("in3", ncnn.Mat(states[1].numpy()).clone())
ret, ncnn_out0 = ex.extract("out0")
assert ret == 0, ret
ret, ncnn_out1 = ex.extract("out1")
assert ret == 0, ret
ret, ncnn_out2 = ex.extract("out2")
assert ret == 0, ret
ret, ncnn_out3 = ex.extract("out3")
assert ret == 0, ret
encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone()
encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to(
torch.int32
)
hx = torch.from_numpy(ncnn_out2.numpy()).clone()
cx = torch.from_numpy(ncnn_out3.numpy()).clone()
return encoder_out, encoder_out_lens, hx, cx
def run_decoder(self, decoder_input):
assert decoder_input.dtype == torch.int32
with self.decoder_net.create_extractor() as ex:
# ex.set_num_threads(10)
ex.input("in0", ncnn.Mat(decoder_input.numpy()).clone())
ret, ncnn_out0 = ex.extract("out0")
assert ret == 0, ret
decoder_out = torch.from_numpy(ncnn_out0.numpy()).clone()
return decoder_out
def run_joiner(self, encoder_out, decoder_out):
with self.joiner_net.create_extractor() as ex:
# ex.set_num_threads(10)
ex.input("in0", ncnn.Mat(encoder_out.numpy()).clone())
ex.input("in1", ncnn.Mat(decoder_out.numpy()).clone())
ret, ncnn_out0 = ex.extract("out0")
assert ret == 0, ret
joiner_out = torch.from_numpy(ncnn_out0.numpy()).clone()
return joiner_out
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
def create_streaming_feature_extractor() -> OnlineFeature:
"""Create a CPU streaming feature extractor.
At present, we assume it returns a fbank feature extractor with
fixed options. In the future, we will support passing in the options
from outside.
Returns:
Return a CPU streaming feature extractor.
"""
opts = FbankOptions()
opts.device = "cpu"
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
return OnlineFbank(opts)
def greedy_search(
model: Model,
encoder_out: torch.Tensor,
decoder_out: Optional[torch.Tensor] = None,
hyp: Optional[List[int]] = None,
):
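# Unlike the batch greedy search in ncnn-decode.py, this function
# processes a single encoder output frame per call and threads `hyp`
# and `decoder_out` through successive calls.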
assert encoder_out.ndim == 1
context_size = 2
blank_id = 0
if decoder_out is None:
assert hyp is None, hyp
hyp = [blank_id] * context_size
decoder_input = torch.tensor(
hyp, dtype=torch.int32
) # (1, context_size)
decoder_out = model.run_decoder(decoder_input).squeeze(0)
else:
assert decoder_out.ndim == 1
assert hyp is not None, hyp
joiner_out = model.run_joiner(encoder_out, decoder_out)
y = joiner_out.argmax(dim=0).tolist()
if y != blank_id:
hyp.append(y)
decoder_input = hyp[-context_size:]
decoder_input = torch.tensor(decoder_input, dtype=torch.int32)
decoder_out = model.run_decoder(decoder_input).squeeze(0)
return hyp, decoder_out
def main():
args = get_args()
logging.info(vars(args))
model = Model(args)
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model_filename)
sound_file = args.sound_filename
sample_rate = 16000
logging.info("Constructing Fbank computer")
online_fbank = create_streaming_feature_extractor()
logging.info(f"Reading sound files: {sound_file}")
wave_samples = read_sound_files(
filenames=[sound_file],
expected_sample_rate=sample_rate,
)[0]
logging.info(wave_samples.shape)
num_encoder_layers = 12
batch_size = 1
d_model = 512
rnn_hidden_size = 1024
states = (
torch.zeros(num_encoder_layers, batch_size, d_model),
torch.zeros(
num_encoder_layers,
batch_size,
rnn_hidden_size,
),
)
hyp = None
decoder_out = None
num_processed_frames = 0
segment = 9
offset = 4
chunk = 3200 # 0.2 second
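# `segment` and `offset` are assumed to follow from the encoder's
# convolutional subsampling: 9 feature frames produce one encoder output
# frame, and the analysis window then advances by 4 frames.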
start = 0
while start < wave_samples.numel():
end = min(start + chunk, wave_samples.numel())
samples = wave_samples[start:end]
start += chunk
online_fbank.accept_waveform(
sampling_rate=sample_rate,
waveform=samples,
)
while online_fbank.num_frames_ready - num_processed_frames >= segment:
frames = []
for i in range(segment):
frames.append(online_fbank.get_frame(num_processed_frames + i))
num_processed_frames += offset
frames = torch.cat(frames, dim=0)
encoder_out, encoder_out_lens, hx, cx = model.run_encoder(
frames, states
)
states = (hx, cx)
hyp, decoder_out = greedy_search(
model, encoder_out.squeeze(0), decoder_out, hyp
)
online_fbank.accept_waveform(
sampling_rate=sample_rate, waveform=torch.zeros(8000, dtype=torch.int32)
)
online_fbank.input_finished()
while online_fbank.num_frames_ready - num_processed_frames >= segment:
frames = []
for i in range(segment):
frames.append(online_fbank.get_frame(num_processed_frames + i))
num_processed_frames += offset
frames = torch.cat(frames, dim=0)
encoder_out, encoder_out_lens, hx, cx = model.run_encoder(
frames, states
)
states = (hx, cx)
hyp, decoder_out = greedy_search(
model, encoder_out.squeeze(0), decoder_out, hyp
)
context_size = 2
logging.info(sound_file)
logging.info(sp.decode(hyp[context_size:]))
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

File diff suppressed because it is too large

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/__init__.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/asr_datamodule.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/beam_search.py

View File

@ -0,0 +1,818 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./lstm_transducer_stateless3/decode.py \
--epoch 40 \
--avg 20 \
--exp-dir ./lstm_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method greedy_search
(2) beam search (not recommended)
./lstm_transducer_stateless3/decode.py \
--epoch 40 \
--avg 20 \
--exp-dir ./lstm_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./lstm_transducer_stateless3/decode.py \
--epoch 40 \
--avg 20 \
--exp-dir ./lstm_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search (one best)
./lstm_transducer_stateless3/decode.py \
--epoch 40 \
--avg 20 \
--exp-dir ./lstm_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(5) fast beam search (nbest)
./lstm_transducer_stateless3/decode.py \
--epoch 40 \
--avg 20 \
--exp-dir ./lstm_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(6) fast beam search (nbest oracle WER)
./lstm_transducer_stateless3/decode.py \
--epoch 40 \
--avg 20 \
--exp-dir ./lstm_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(7) fast beam search (with LG)
./lstm_transducer_stateless3/decode.py \
--epoch 40 \
--avg 20 \
--exp-dir ./lstm_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_LG \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
"""
import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="lstm_transducer_stateless/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_bpe_500",
help="The lang dir containing word table and LG graph",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
If you use fast_beam_search_nbest_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=20.0,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search,
fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle
""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_nbest_LG.
It specifies the scale for n-gram LM scores.
""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--max-states",
type=int,
default=64,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--num-paths",
type=int,
default=200,
help="""Number of paths for nbest decoding.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
add_model_arguments(parser)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search".
If beam search with a beam size of 7 is used, it would be
"beam_7".
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
# tail padding here to alleviate the tail deletion problem
num_tail_padded_frames = 35
feature = torch.nn.functional.pad(
feature,
(0, 0, 0, num_tail_padded_frames),
mode="constant",
value=LOG_EPS,
)
feature_lens += num_tail_padded_frames
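# The appended frames are filled with LOG_EPS (log of a very small
# value), which corresponds to near-silence in the log-mel domain.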
encoder_out, encoder_out_lens, _ = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest":
hyp_tokens = fast_beam_search_nbest(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_oracle":
hyp_tokens = fast_beam_search_nbest_oracle(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
ref_texts=sp.encode(supervisions["text"]),
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append(sp.decode(hyp).split())
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif "fast_beam_search" in params.decoding_method:
key = f"beam_{params.beam}_"
key += f"max_contexts_{params.max_contexts}_"
key += f"max_states_{params.max_states}"
if "nbest" in params.decoding_method:
key += f"_num_paths_{params.num_paths}_"
key += f"nbest_scale_{params.nbest_scale}"
if "LG" in params.decoding_method:
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {key: hyps}
else:
return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains three elements:
the cut id, the reference transcript, and the predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 50
else:
log_interval = 20
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
word_table=word_table,
batch=batch,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
if "nbest" in params.decoding_method:
params.suffix += f"-nbest-scale-{params.nbest_scale}"
params.suffix += f"-num-paths-{params.num_paths}"
if "LG" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG":
lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
word_table = None
decoding_graph = k2.trivial_graph(
params.vocab_size - 1, device=device
)
else:
decoding_graph = None
word_table = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"]
test_dl = [test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/decoder.py

View File

@ -0,0 +1 @@
../transducer_stateless/encoder_interface.py

View File

@ -0,0 +1,388 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
(1) Export to torchscript model using torch.jit.trace()
./lstm_transducer_stateless3/export.py \
--exp-dir ./lstm_transducer_stateless3/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 40 \
--avg 20 \
--jit-trace 1
It will generate 3 files: `encoder_jit_trace.pt`,
`decoder_jit_trace.pt`, and `joiner_jit_trace.pt`.
(2) Export `model.state_dict()`
./lstm_transducer_stateless3/export.py \
--exp-dir ./lstm_transducer_stateless3/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 40 \
--avg 20
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
load it by `icefall.checkpoint.load_checkpoint()`.
To use the generated file with `lstm_transducer_stateless3/decode.py`,
you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
./lstm_transducer_stateless3/decode.py \
--exp-dir ./lstm_transducer_stateless3/exp \
--epoch 9999 \
--avg 1 \
--max-duration 600 \
--decoding-method greedy_search \
--bpe-model data/lang_bpe_500/bpe.model
Check ./pretrained.py for its usage.
Note: If you don't want to train a model from scratch, we have
provided one for you. You can get it at
https://huggingface.co/Zengwei/icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18
with the following commands:
sudo apt-get install git-lfs
git lfs install
git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18
# You will find the pre-trained model in icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18/exp
"""
import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import torch
import torch.nn as nn
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless3/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--jit-trace",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.trace.
It will generate 3 files:
- encoder_jit_trace.pt
- decoder_jit_trace.pt
- joiner_jit_trace.pt
Check ./jit_pretrained.py for how to use them.
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
add_model_arguments(parser)
return parser
def export_encoder_model_jit_trace(
encoder_model: nn.Module,
encoder_filename: str,
) -> None:
"""Export the given encoder model with torch.jit.trace()
Note: The warmup argument is fixed to 1.
Args:
encoder_model:
The input encoder model
encoder_filename:
The filename to save the exported model.
"""
x = torch.zeros(1, 100, 80, dtype=torch.float32)
x_lens = torch.tensor([100], dtype=torch.int64)
states = encoder_model.get_init_states()
traced_model = torch.jit.trace(encoder_model, (x, x_lens, states))
traced_model.save(encoder_filename)
logging.info(f"Saved to {encoder_filename}")
def export_decoder_model_jit_trace(
decoder_model: nn.Module,
decoder_filename: str,
) -> None:
"""Export the given decoder model with torch.jit.trace()
Note: The argument need_pad is fixed to False.
Args:
decoder_model:
The input decoder model
decoder_filename:
The filename to save the exported model.
"""
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
need_pad = torch.tensor([False])
traced_model = torch.jit.trace(decoder_model, (y, need_pad))
traced_model.save(decoder_filename)
logging.info(f"Saved to {decoder_filename}")
def export_joiner_model_jit_trace(
joiner_model: nn.Module,
joiner_filename: str,
) -> None:
"""Export the given joiner model with torch.jit.trace()
Note: The argument project_input is fixed to True. Users should not
project encoder_out/decoder_out themselves; the exported joiner
does that for them.
Args:
joiner_model:
The input joiner model
joiner_filename:
The filename to save the exported model.
"""
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out))
traced_model.save(joiner_filename)
logging.info(f"Saved to {joiner_filename}")
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to("cpu")
model.eval()
if params.jit_trace is True:
convert_scaled_to_non_scaled(model, inplace=True)
logging.info("Using torch.jit.trace()")
encoder_filename = params.exp_dir / "encoder_jit_trace.pt"
export_encoder_model_jit_trace(model.encoder, encoder_filename)
decoder_filename = params.exp_dir / "decoder_jit_trace.pt"
export_decoder_model_jit_trace(model.decoder, decoder_filename)
joiner_filename = params.exp_dir / "joiner_jit_trace.pt"
export_joiner_model_jit_trace(model.joiner, joiner_filename)
else:
logging.info("Not using torchscript")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,322 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads torchscript models, either exported by `torch.jit.trace()`
or by `torch.jit.script()`, and uses them to decode waves.
You can use the following command to get the exported models:
./lstm_transducer_stateless3/export.py \
--exp-dir ./lstm_transducer_stateless3/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 40 \
--avg 15 \
--jit-trace 1
Usage of this script:
./lstm_transducer_stateless3/jit_pretrained.py \
--encoder-model-filename ./lstm_transducer_stateless3/exp/encoder_jit_trace.pt \
--decoder-model-filename ./lstm_transducer_stateless3/exp/decoder_jit_trace.pt \
--joiner-model-filename ./lstm_transducer_stateless3/exp/joiner_jit_trace.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
/path/to/foo.wav \
/path/to/bar.wav
"""
import argparse
import logging
import math
from typing import List
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--encoder-model-filename",
type=str,
required=True,
help="Path to the encoder torchscript model. ",
)
parser.add_argument(
"--decoder-model-filename",
type=str,
required=True,
help="Path to the decoder torchscript model. ",
)
parser.add_argument(
"--joiner-model-filename",
type=str,
required=True,
help="Path to the joiner torchscript model. ",
)
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to bpe.model.""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="Context size of the decoder model",
)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
def greedy_search(
decoder: torch.jit.ScriptModule,
joiner: torch.jit.ScriptModule,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
context_size: int,
) -> List[List[int]]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
decoder:
The decoder model.
joiner:
The joiner model.
encoder_out:
A 3-D tensor of shape (N, T, C)
encoder_out_lens:
A 1-D tensor of shape (N,).
context_size:
The context size of the decoder model.
Returns:
Return the decoded results for each utterance.
"""
assert encoder_out.ndim == 3
assert encoder_out.size(0) >= 1, encoder_out.size(0)
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
device = encoder_out.device
blank_id = 0 # hard-code to 0
batch_size_list = packed_encoder_out.batch_sizes.tolist()
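# batch_size_list[t] is the number of utterances that still have a frame at
# time step t; pack_padded_sequence orders sequences longest-first internally,
# so these counts are non-increasing.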
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
hyps = [[blank_id] * context_size for _ in range(N)]
decoder_input = torch.tensor(
hyps,
device=device,
dtype=torch.int64,
) # (N, context_size)
decoder_out = decoder(
decoder_input,
need_pad=torch.tensor([False]),
).squeeze(1)
offset = 0
for batch_size in batch_size_list:
start = offset
end = offset + batch_size
current_encoder_out = packed_encoder_out.data[start:end]
current_encoder_out = current_encoder_out
# current_encoder_out's shape: (batch_size, encoder_out_dim)
offset = end
decoder_out = decoder_out[:batch_size]
logits = joiner(
current_encoder_out,
decoder_out,
)
# logits' shape: (batch_size, vocab_size)
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
hyps[i].append(v)
emitted = True
if emitted:
# update decoder output
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
decoder_input = torch.tensor(
decoder_input,
device=device,
dtype=torch.int64,
)
decoder_out = decoder(
decoder_input,
need_pad=torch.tensor([False]),
)
decoder_out = decoder_out.squeeze(1)
sorted_ans = [h[context_size:] for h in hyps]
ans = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
encoder = torch.jit.load(args.encoder_model_filename)
decoder = torch.jit.load(args.decoder_model_filename)
joiner = torch.jit.load(args.joiner_model_filename)
encoder.eval()
decoder.eval()
joiner.eval()
encoder.to(device)
decoder.to(device)
joiner.to(device)
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model)
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = args.sample_rate
opts.mel_opts.num_bins = 80
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {args.sound_files}")
waves = read_sound_files(
filenames=args.sound_files,
expected_sample_rate=args.sample_rate,
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features,
batch_first=True,
padding_value=math.log(1e-10),
)
feature_lengths = torch.tensor(feature_lengths, device=device)
states = encoder.get_init_states(batch_size=features.size(0), device=device)
encoder_out, encoder_out_lens, _ = encoder(
x=features,
x_lens=feature_lengths,
states=states,
)
hyps = greedy_search(
decoder=decoder,
joiner=joiner,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
context_size=args.context_size,
)
s = "\n"
for filename, hyp in zip(args.sound_files, hyps):
words = sp.decode(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/joiner.py

View File

@ -0,0 +1,860 @@
# Copyright 2022 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import math
from typing import List, Optional, Tuple
import torch
from encoder_interface import EncoderInterface
from scaling import (
ActivationBalancer,
BasicNorm,
DoubleSwish,
ScaledConv2d,
ScaledLinear,
ScaledLSTM,
)
from torch import nn
LOG_EPSILON = math.log(1e-10)
def unstack_states(
states: Tuple[torch.Tensor, torch.Tensor]
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
"""
Unstack the lstm states corresponding to a batch of utterances into a list
of states, where the i-th entry is the state from the i-th utterance.
Args:
states:
A tuple of 2 elements.
``states[0]`` is the lstm hidden states of a batch of utterances.
``states[1]`` is the lstm cell states of a batch of utterances.
Returns:
A list of states.
``states[i]`` is a tuple of 2 elements for the i-th utterance.
``states[i][0]`` is the lstm hidden states of the i-th utterance.
``states[i][1]`` is the lstm cell states of the i-th utterance.
"""
hidden_states, cell_states = states
list_hidden_states = hidden_states.unbind(dim=1)
list_cell_states = cell_states.unbind(dim=1)
ans = [
(h.unsqueeze(1), c.unsqueeze(1))
for (h, c) in zip(list_hidden_states, list_cell_states)
]
return ans
def stack_states(
states_list: List[Tuple[torch.Tensor, torch.Tensor]]
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Stack list of lstm states corresponding to separate utterances into a single
lstm state so that it can be used as an input for lstm when those utterances
are formed into a batch.
Args:
states_list:
Each element in states_list corresponds to the lstm state for a single
utterance.
``states_list[i]`` is a tuple of 2 elements for the i-th utterance.
``states_list[i][0]`` is the lstm hidden states of the i-th utterance.
``states_list[i][1]`` is the lstm cell states of the i-th utterance.
Returns:
A new state corresponding to a batch of utterances.
It is a tuple of 2 elements.
``states[0]`` is the lstm hidden states of the batch of utterances.
``states[1]`` is the lstm cell states of the batch of utterances.
"""
hidden_states = torch.cat([s[0] for s in states_list], dim=1)
cell_states = torch.cat([s[1] for s in states_list], dim=1)
ans = (hidden_states, cell_states)
return ans
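# A minimal round-trip sketch (hypothetical shapes, e.g. num_layers=2, a batch
# of 3 utterances, d_model=4, rnn_hidden_size=5):
#
#   h = torch.zeros(2, 3, 4)
#   c = torch.zeros(2, 3, 5)
#   per_utt = unstack_states((h, c))  # 3 tuples of shapes (2, 1, 4) and (2, 1, 5)
#   h2, c2 = stack_states(per_utt)    # back to (2, 3, 4) and (2, 3, 5)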
class RNN(EncoderInterface):
"""
Args:
num_features (int):
Number of input features.
subsampling_factor (int):
Subsampling factor of encoder (convolution layers before lstm layers) (default=4). # noqa
d_model (int):
Output dimension (default=512).
dim_feedforward (int):
Feedforward dimension (default=2048).
rnn_hidden_size (int):
Hidden dimension for lstm layers (default=1024).
grad_norm_threshold:
For each sequence element in batch, its gradient will be
filtered out if the gradient norm is larger than
`grad_norm_threshold * median`, where `median` is the median
value of gradient norms of all elements in the batch.
num_encoder_layers (int):
Number of encoder layers (default=12).
dropout (float):
Dropout rate (default=0.1).
layer_dropout (float):
Dropout value for model-level warmup (default=0.075).
aux_layer_period (int):
Period of auxiliary layers used for random combiner during training.
If set to 0, will not use the random combiner (Default).
You can set a positive integer to use the random combiner, e.g., 3.
"""
def __init__(
self,
num_features: int,
subsampling_factor: int = 4,
d_model: int = 512,
dim_feedforward: int = 2048,
rnn_hidden_size: int = 1024,
grad_norm_threshold: float = 10.0,
num_encoder_layers: int = 12,
dropout: float = 0.1,
layer_dropout: float = 0.075,
aux_layer_period: int = 0,
) -> None:
super(RNN, self).__init__()
self.num_features = num_features
self.subsampling_factor = subsampling_factor
if subsampling_factor != 4:
raise NotImplementedError("Support only 'subsampling_factor=4'.")
# self.encoder_embed converts the input of shape (N, T, num_features)
# to the shape (N, T//subsampling_factor, d_model).
# That is, it does two things simultaneously:
# (1) subsampling: T -> T//subsampling_factor
# (2) embedding: num_features -> d_model
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
self.num_encoder_layers = num_encoder_layers
self.d_model = d_model
self.rnn_hidden_size = rnn_hidden_size
encoder_layer = RNNEncoderLayer(
d_model=d_model,
dim_feedforward=dim_feedforward,
rnn_hidden_size=rnn_hidden_size,
grad_norm_threshold=grad_norm_threshold,
dropout=dropout,
layer_dropout=layer_dropout,
)
self.encoder = RNNEncoder(
encoder_layer,
num_encoder_layers,
aux_layers=list(
range(
num_encoder_layers // 3,
num_encoder_layers - 1,
aux_layer_period,
)
)
if aux_layer_period > 0
else None,
)
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
warmup: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Args:
x:
The input tensor. Its shape is (N, T, C), where N is the batch size,
T is the sequence length, C is the feature dimension.
x_lens:
A tensor of shape (N,), containing the number of frames in `x`
before padding.
states:
A tuple of 2 tensors (optional). It is for streaming inference.
states[0] is the hidden states of all layers,
with shape of (num_layers, N, d_model);
states[1] is the cell states of all layers,
with shape of (num_layers, N, rnn_hidden_size).
warmup:
A floating point value that gradually increases from 0 throughout
training; when it is >= 1.0 we are "fully warmed up". It is used
to turn modules on sequentially.
Returns:
A tuple of 3 tensors:
- embeddings: its shape is (N, T', d_model), where T' is the output
sequence lengths.
- lengths: a tensor of shape (batch_size,) containing the number of
frames in `embeddings` before padding.
- updated states, whose shape is the same as the input states.
"""
x = self.encoder_embed(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
# lengths = ((x_lens - 3) // 2 - 1) // 2  # issues a warning
#
# Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
lengths = (((x_lens - 3) >> 1) - 1) >> 1
if not torch.jit.is_tracing():
assert x.size(0) == lengths.max().item()
if states is None:
x = self.encoder(x, warmup=warmup)[0]
# torch.jit.trace requires returned types to be the same as annotated # noqa
new_states = (torch.empty(0), torch.empty(0))
else:
assert not self.training
assert len(states) == 2
if not torch.jit.is_tracing():
# for hidden state
assert states[0].shape == (
self.num_encoder_layers,
x.size(1),
self.d_model,
)
# for cell state
assert states[1].shape == (
self.num_encoder_layers,
x.size(1),
self.rnn_hidden_size,
)
x, new_states = self.encoder(x, states)
x = x.permute(1, 0, 2) # (T, N, C) -> (N, T, C)
return x, lengths, new_states
@torch.jit.export
def get_init_states(
self, batch_size: int = 1, device: torch.device = torch.device("cpu")
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Get model initial states."""
# for rnn hidden states
hidden_states = torch.zeros(
(self.num_encoder_layers, batch_size, self.d_model), device=device
)
cell_states = torch.zeros(
(self.num_encoder_layers, batch_size, self.rnn_hidden_size),
device=device,
)
return (hidden_states, cell_states)
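# A minimal streaming sketch (hypothetical values; assumes the default
# constructor arguments above):
#
#   model = RNN(num_features=80).eval()
#   states = model.get_init_states(batch_size=1)
#   chunk = torch.randn(1, 9, 80)  # the subsampling front-end needs T >= 9
#   out, out_lens, states = model(chunk, torch.tensor([9]), states)
#   # pass the returned `states` back in together with the next chunk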
class RNNEncoderLayer(nn.Module):
"""
RNNEncoderLayer is made up of lstm and feedforward networks.
For stable training, in each lstm module, gradient filter
is applied to filter out extremely large elements in batch gradients
and also the module parameters with soft masks.
Args:
d_model:
The number of expected features in the input (required).
dim_feedforward:
The dimension of feedforward network model (default=2048).
rnn_hidden_size:
The hidden dimension of rnn layer.
grad_norm_threshold:
For each sequence element in batch, its gradient will be
filtered out if the gradient norm is larger than
`grad_norm_threshold * median`, where `median` is the median
value of gradient norms of all elements in the batch.
dropout:
The dropout value (default=0.1).
layer_dropout:
The dropout value for model-level warmup (default=0.075).
"""
def __init__(
self,
d_model: int,
dim_feedforward: int,
rnn_hidden_size: int,
grad_norm_threshold: float = 10.0,
dropout: float = 0.1,
layer_dropout: float = 0.075,
) -> None:
super(RNNEncoderLayer, self).__init__()
self.layer_dropout = layer_dropout
self.d_model = d_model
self.rnn_hidden_size = rnn_hidden_size
assert rnn_hidden_size >= d_model, (rnn_hidden_size, d_model)
self.lstm = ScaledLSTM(
input_size=d_model,
hidden_size=rnn_hidden_size,
proj_size=d_model if rnn_hidden_size > d_model else 0,
num_layers=1,
dropout=0.0,
grad_norm_threshold=grad_norm_threshold,
)
self.feed_forward = nn.Sequential(
ScaledLinear(d_model, dim_feedforward),
ActivationBalancer(channel_dim=-1),
DoubleSwish(),
nn.Dropout(dropout),
ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
)
self.norm_final = BasicNorm(d_model)
# try to ensure the output is close to zero-mean (or at least, zero-median). # noqa
self.balancer = ActivationBalancer(
channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
)
self.dropout = nn.Dropout(dropout)
def forward(
self,
src: torch.Tensor,
states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
warmup: float = 1.0,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Pass the input through the encoder layer.
Args:
src:
The sequence to the encoder layer (required).
Its shape is (S, N, E), where S is the sequence length,
N is the batch size, and E is the feature number.
states:
A tuple of 2 tensors (optional). It is for streaming inference.
states[0] is the hidden states of all layers,
with shape of (1, N, d_model);
states[1] is the cell states of all layers,
with shape of (1, N, rnn_hidden_size).
warmup:
It controls selective bypass of layers; if < 1.0, we will
bypass layers more frequently.
"""
src_orig = src
warmup_scale = min(0.1 + warmup, 1.0)
# alpha = 1.0 means fully use this encoder layer, 0.0 would mean
# completely bypass it.
if self.training:
alpha = (
warmup_scale
if torch.rand(()).item() <= (1.0 - self.layer_dropout)
else 0.1
)
else:
alpha = 1.0
# lstm module
if states is None:
src_lstm = self.lstm(src)[0]
# torch.jit.trace requires returned types to be the same as annotated
new_states = (torch.empty(0), torch.empty(0))
else:
assert not self.training
assert len(states) == 2
if not torch.jit.is_tracing():
# for hidden state
assert states[0].shape == (1, src.size(1), self.d_model)
# for cell state
assert states[1].shape == (1, src.size(1), self.rnn_hidden_size)
src_lstm, new_states = self.lstm(src, states)
src = src + self.dropout(src_lstm)
# feed forward module
src = src + self.dropout(self.feed_forward(src))
src = self.norm_final(self.balancer(src))
if alpha != 1.0:
src = alpha * src + (1 - alpha) * src_orig
return src, new_states
class RNNEncoder(nn.Module):
"""
RNNEncoder is a stack of N encoder layers.
Args:
encoder_layer:
An instance of the RNNEncoderLayer() class (required).
num_layers:
The number of sub-encoder-layers in the encoder (required).
"""
def __init__(
self,
encoder_layer: nn.Module,
num_layers: int,
aux_layers: Optional[List[int]] = None,
) -> None:
super(RNNEncoder, self).__init__()
self.layers = nn.ModuleList(
[copy.deepcopy(encoder_layer) for i in range(num_layers)]
)
self.num_layers = num_layers
self.d_model = encoder_layer.d_model
self.rnn_hidden_size = encoder_layer.rnn_hidden_size
self.aux_layers: List[int] = []
self.combiner: Optional[nn.Module] = None
if aux_layers is not None:
assert len(set(aux_layers)) == len(aux_layers)
assert num_layers - 1 not in aux_layers
self.aux_layers = aux_layers + [num_layers - 1]
self.combiner = RandomCombine(
num_inputs=len(self.aux_layers),
final_weight=0.5,
pure_prob=0.333,
stddev=2.0,
)
def forward(
self,
src: torch.Tensor,
states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
warmup: float = 1.0,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Pass the input through the encoder layer in turn.
Args:
src:
The sequence to the encoder layer (required).
Its shape is (S, N, E), where S is the sequence length,
N is the batch size, and E is the feature number.
states:
A tuple of 2 tensors (optional). It is for streaming inference.
states[0] is the hidden states of all layers,
with shape of (num_layers, N, d_model);
states[1] is the cell states of all layers,
with shape of (num_layers, N, rnn_hidden_size).
warmup:
It controls selective bypass of layers; if < 1.0, we will
bypass layers more frequently.
"""
if states is not None:
assert not self.training
assert len(states) == 2
if not torch.jit.is_tracing():
# for hidden state
assert states[0].shape == (
self.num_layers,
src.size(1),
self.d_model,
)
# for cell state
assert states[1].shape == (
self.num_layers,
src.size(1),
self.rnn_hidden_size,
)
output = src
outputs = []
new_hidden_states = []
new_cell_states = []
for i, mod in enumerate(self.layers):
if states is None:
output = mod(output, warmup=warmup)[0]
else:
layer_state = (
states[0][i : i + 1, :, :], # h: (1, N, d_model)
states[1][i : i + 1, :, :], # c: (1, N, rnn_hidden_size)
)
output, (h, c) = mod(output, layer_state)
new_hidden_states.append(h)
new_cell_states.append(c)
if self.combiner is not None and i in self.aux_layers:
outputs.append(output)
if self.combiner is not None:
output = self.combiner(outputs)
if states is None:
new_states = (torch.empty(0), torch.empty(0))
else:
new_states = (
torch.cat(new_hidden_states, dim=0),
torch.cat(new_cell_states, dim=0),
)
return output, new_states
class Conv2dSubsampling(nn.Module):
"""Convolutional 2D subsampling (to 1/4 length).
Convert an input of shape (N, T, idim) to an output
with shape (N, T', odim), where
T' = ((T-3)//2-1)//2, which approximates T' == T//4
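For example, T = 100 gives T' = ((100 - 3) // 2 - 1) // 2 = 23.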
It is based on
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
"""
def __init__(
self,
in_channels: int,
out_channels: int,
layer1_channels: int = 8,
layer2_channels: int = 32,
layer3_channels: int = 128,
) -> None:
"""
Args:
in_channels:
Number of channels in. The input shape is (N, T, in_channels).
Caution: It requires: T >= 9, in_channels >= 9.
out_channels:
Output dim. The output shape is (N, ((T-3)//2-1)//2, out_channels)
layer1_channels:
Number of channels in layer1
layer2_channels:
Number of channels in layer2
layer3_channels:
Number of channels in layer3
"""
assert in_channels >= 9
super().__init__()
self.conv = nn.Sequential(
ScaledConv2d(
in_channels=1,
out_channels=layer1_channels,
kernel_size=3,
padding=0,
),
ActivationBalancer(channel_dim=1),
DoubleSwish(),
ScaledConv2d(
in_channels=layer1_channels,
out_channels=layer2_channels,
kernel_size=3,
stride=2,
),
ActivationBalancer(channel_dim=1),
DoubleSwish(),
ScaledConv2d(
in_channels=layer2_channels,
out_channels=layer3_channels,
kernel_size=3,
stride=2,
),
ActivationBalancer(channel_dim=1),
DoubleSwish(),
)
self.out = ScaledLinear(
layer3_channels * (((in_channels - 3) // 2 - 1) // 2), out_channels
)
# set learn_eps=False because out_norm is preceded by `out`, and `out`
# itself has learned scale, so the extra degree of freedom is not
# needed.
self.out_norm = BasicNorm(out_channels, learn_eps=False)
# constrain median of output to be close to zero.
self.out_balancer = ActivationBalancer(
channel_dim=-1, min_positive=0.45, max_positive=0.55
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x.
Args:
x:
Its shape is (N, T, idim).
Returns:
Return a tensor of shape (N, ((T-3)//2-1)//2, odim)
"""
# On entry, x is (N, T, idim)
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
x = self.conv(x)
# Now x is of shape (N, odim, ((T-3)//2-1)//2, ((idim-3)//2-1)//2)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
# Now x is of shape (N, ((T-3)//2-1)//2, odim)
x = self.out_norm(x)
x = self.out_balancer(x)
return x
class RandomCombine(nn.Module):
"""
This module combines a list of Tensors, all with the same shape, to
produce a single output of that same shape which, in training time,
is a random combination of all the inputs; but which in test time
will be just the last input.
The idea is that the list of Tensors will be a list of outputs of multiple
conformer layers. This has a similar effect as iterated loss. (See:
DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER
NETWORKS).
"""
def __init__(
self,
num_inputs: int,
final_weight: float = 0.5,
pure_prob: float = 0.5,
stddev: float = 2.0,
) -> None:
"""
Args:
num_inputs:
The number of tensor inputs, which equals the number of layers'
outputs that are fed into this module. E.g. in an 18-layer neural
net if we output layers 16, 12, 18, num_inputs would be 3.
final_weight:
The amount of weight or probability we assign to the
final layer when randomly choosing layers or when choosing
continuous layer weights.
pure_prob:
The probability, on each frame, with which we choose
only a single layer to output (rather than an interpolation)
stddev:
A standard deviation that we add to log-probs for computing
randomized weights.
The method of choosing which layers, or combinations of layers, to use,
is conceptually as follows::
With probability `pure_prob`::
With probability `final_weight`: choose final layer,
Else: choose random non-final layer.
Else::
Choose initial log-weights that correspond to assigning
weight `final_weight` to the final layer and equal
weights to other layers; then add Gaussian noise
with variance `stddev` to these log-weights, and normalize
to weights (note: the average weight assigned to the
final layer here will not be `final_weight` if stddev>0).
"""
super().__init__()
assert 0 <= pure_prob <= 1, pure_prob
assert 0 < final_weight < 1, final_weight
assert num_inputs >= 1
self.num_inputs = num_inputs
self.final_weight = final_weight
self.pure_prob = pure_prob
self.stddev = stddev
self.final_log_weight = (
torch.tensor(
(final_weight / (1 - final_weight)) * (self.num_inputs - 1)
)
.log()
.item()
)
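# With all other layers at log-weight 0, the softmax of this offset assigns
# exactly `final_weight` of the probability mass to the final layer (before
# Gaussian noise is added).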
def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
"""Forward function.
Args:
inputs:
A list of Tensor, e.g. from various layers of a transformer.
All must be the same shape, of (*, num_channels)
Returns:
A Tensor of shape (*, num_channels). In test mode
this is just the final input.
"""
num_inputs = self.num_inputs
assert len(inputs) == num_inputs
if not self.training or torch.jit.is_scripting():
return inputs[-1]
# Shape of weights: (*, num_inputs)
num_channels = inputs[0].shape[-1]
num_frames = inputs[0].numel() // num_channels
ndim = inputs[0].ndim
# stacked_inputs: (num_frames, num_channels, num_inputs)
stacked_inputs = torch.stack(inputs, dim=ndim).reshape(
(num_frames, num_channels, num_inputs)
)
# weights: (num_frames, num_inputs)
weights = self._get_random_weights(
inputs[0].dtype, inputs[0].device, num_frames
)
weights = weights.reshape(num_frames, num_inputs, 1)
# ans: (num_frames, num_channels, 1)
ans = torch.matmul(stacked_inputs, weights)
# ans: (*, num_channels)
ans = ans.reshape(inputs[0].shape[:-1] + (num_channels,))
# The following if causes errors for torch script in torch 1.6.0
# if __name__ == "__main__":
# # for testing only...
# print("Weights = ", weights.reshape(num_frames, num_inputs))
return ans
def _get_random_weights(
self, dtype: torch.dtype, device: torch.device, num_frames: int
) -> torch.Tensor:
"""Return a tensor of random weights, of shape
`(num_frames, self.num_inputs)`,
Args:
dtype:
The data-type desired for the answer, e.g. float, double.
device:
The device needed for the answer.
num_frames:
The number of sets of weights desired
Returns:
A tensor of shape (num_frames, self.num_inputs), such that
`ans.sum(dim=1)` is all ones.
"""
pure_prob = self.pure_prob
if pure_prob == 0.0:
return self._get_random_mixed_weights(dtype, device, num_frames)
elif pure_prob == 1.0:
return self._get_random_pure_weights(dtype, device, num_frames)
else:
p = self._get_random_pure_weights(dtype, device, num_frames)
m = self._get_random_mixed_weights(dtype, device, num_frames)
return torch.where(
torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m
)
def _get_random_pure_weights(
self, dtype: torch.dtype, device: torch.device, num_frames: int
):
"""Return a tensor of random one-hot weights, of shape
`(num_frames, self.num_inputs)`,
Args:
dtype:
The data-type desired for the answer, e.g. float, double.
device:
The device needed for the answer.
num_frames:
The number of sets of weights desired.
Returns:
A one-hot tensor of shape `(num_frames, self.num_inputs)`, with
exactly one weight equal to 1.0 on each frame.
"""
final_prob = self.final_weight
# final contains self.num_inputs - 1 in all elements
final = torch.full((num_frames,), self.num_inputs - 1, device=device)
# nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. # noqa
nonfinal = torch.randint(
self.num_inputs - 1, (num_frames,), device=device
)
indexes = torch.where(
torch.rand(num_frames, device=device) < final_prob, final, nonfinal
)
ans = torch.nn.functional.one_hot(
indexes, num_classes=self.num_inputs
).to(dtype=dtype)
return ans
def _get_random_mixed_weights(
self, dtype: torch.dtype, device: torch.device, num_frames: int
):
"""Return a tensor of random one-hot weights, of shape
`(num_frames, self.num_inputs)`,
Args:
dtype:
The data-type desired for the answer, e.g. float, double.
device:
The device needed for the answer.
num_frames:
The number of sets of weights desired.
Returns:
A tensor of shape (num_frames, self.num_inputs), whose elements
are in [0..1] and sum to one over the second axis, i.e.
`ans.sum(dim=1)` is all ones.
"""
logprobs = (
torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device)
* self.stddev
)
logprobs[:, -1] += self.final_log_weight
return logprobs.softmax(dim=1)
def _test_random_combine(final_weight: float, pure_prob: float, stddev: float):
print(
f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}" # noqa
)
num_inputs = 3
num_channels = 50
m = RandomCombine(
num_inputs=num_inputs,
final_weight=final_weight,
pure_prob=pure_prob,
stddev=stddev,
)
x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)]
y = m(x)
assert y.shape == x[0].shape
assert torch.allclose(y, x[0]) # .. since actually all ones.
def _test_random_combine_main():
_test_random_combine(0.999, 0, 0.0)
_test_random_combine(0.5, 0, 0.0)
_test_random_combine(0.999, 0, 0.0)
_test_random_combine(0.5, 0, 0.3)
_test_random_combine(0.5, 1, 0.3)
_test_random_combine(0.5, 0.5, 0.3)
feature_dim = 50
c = RNN(num_features=feature_dim, d_model=128)
batch_size = 5
seq_len = 20
# Just make sure the forward pass runs.
f = c(
torch.randn(batch_size, seq_len, feature_dim),
torch.full((batch_size,), seq_len, dtype=torch.int64),
)
f # to remove flake8 warnings
if __name__ == "__main__":
feature_dim = 80
m = RNN(
num_features=feature_dim,
d_model=512,
rnn_hidden_size=1024,
dim_feedforward=2048,
num_encoder_layers=12,
)
batch_size = 5
seq_len = 20
# Just make sure the forward pass runs.
f = m(
torch.randn(batch_size, seq_len, feature_dim),
torch.full((batch_size,), seq_len, dtype=torch.int64),
warmup=0.5,
)
num_param = sum([p.numel() for p in m.parameters()])
print(f"Number of model parameters: {num_param}")
_test_random_combine_main()

View File

@ -0,0 +1 @@
../lstm_transducer_stateless/model.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/optim.py

View File

@ -0,0 +1,352 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./lstm_transducer_stateless3/pretrained.py \
--checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
(2) beam search
./lstm_transducer_stateless3/pretrained.py \
--checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(3) modified beam search
./lstm_transducer_stateless3/pretrained.py \
--checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(4) fast beam search
./lstm_transducer_stateless3/pretrained.py \
--checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
You can also use `./lstm_transducer_stateless3/exp/epoch-xx.pt`.
Note: ./lstm_transducer_stateless3/exp/pretrained.pt is generated by
./lstm_transducer_stateless3/export.py
"""
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to bpe.model.""",
)
parser.add_argument(
"--method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame. Used only when
--method is greedy_search.
""",
)
add_model_arguments(parser)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Creating model")
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
model.device = device
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
feature_lengths = torch.tensor(feature_lengths, device=device)
encoder_out, encoder_out_lens, _ = model.encoder(
x=features, x_lens=feature_lengths
)
num_waves = encoder_out.size(0)
hyps = []
msg = f"Using {params.method}"
if params.method == "beam_search":
msg += f" with beam size {params.beam_size}"
logging.info(msg)
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
for i in range(num_waves):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split())
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/scaling.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless3/scaling_converter.py

View File

@ -0,0 +1 @@
../lstm_transducer_stateless/stream.py

View File

@ -0,0 +1,968 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./lstm_transducer_stateless3/streaming_decode.py \
--epoch 40 \
--avg 20 \
--exp-dir lstm_transducer_stateless3/exp \
--num-decode-streams 2000 \
--num-encoder-layers 12 \
--rnn-hidden-size 1024 \
--decoding-method greedy_search \
--use-averaged-model True
(2) modified beam search
./lstm_transducer_stateless3/streaming_decode.py \
--epoch 40 \
--avg 20 \
--exp-dir lstm_transducer_stateless3/exp \
--num-decode-streams 2000 \
--num-encoder-layers 12 \
--rnn-hidden-size 1024 \
--decoding-method modified_beam_search \
--use-averaged-model True \
--beam-size 4
(3) fast beam search
./lstm_transducer_stateless3/streaming_decode.py \
--epoch 40 \
--avg 20 \
--exp-dir lstm_transducer_stateless3/exp \
--num-decode-streams 2000 \
--num-encoder-layers 12 \
--rnn-hidden-size 1024 \
--decoding-method fast_beam_search \
--use-averaged-model True \
--beam 4 \
--max-contexts 4 \
--max-states 8
"""
import argparse
import logging
import warnings
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import numpy as np
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import Hypothesis, HypothesisList, get_hyps_shape
from kaldifeat import Fbank, FbankOptions
from lhotse import CutSet
from lstm import LOG_EPSILON, stack_states, unstack_states
from stream import Stream
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.decode import one_best_decoding
from icefall.utils import (
AttributeDict,
get_texts,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=40,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=20,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=False,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="lstm_transducer_stateless3/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An interger indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=20.0,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=64,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--sampling-rate",
type=float,
default=16000,
help="Sample rate of the audio",
)
parser.add_argument(
"--num-decode-streams",
type=int,
default=2000,
help="The number of streams that can be decoded in parallel",
)
add_model_arguments(parser)
return parser
def greedy_search(
model: nn.Module,
encoder_out: torch.Tensor,
streams: List[Stream],
) -> None:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
model:
The transducer model.
encoder_out:
Output from the encoder. Its shape is (N, T, C), where N >= 1.
streams:
A list of Stream objects.
"""
assert len(streams) == encoder_out.size(0)
assert encoder_out.ndim == 3
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = next(model.parameters()).device
T = encoder_out.size(1)
encoder_out = model.joiner.encoder_proj(encoder_out)
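# Project the encoder output once for all frames here; the joiner below is
# then called with project_input=False.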
decoder_input = torch.tensor(
[stream.hyp[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
# decoder_out is of shape (batch_size, 1, decoder_out_dim)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
for t in range(T):
# current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
current_encoder_out = encoder_out[:, t : t + 1, :] # noqa
logits = model.joiner(
current_encoder_out.unsqueeze(2),
decoder_out.unsqueeze(1),
project_input=False,
)
# logits' shape: (batch_size, vocab_size)
logits = logits.squeeze(1).squeeze(1)
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
streams[i].hyp.append(v)
emitted = True
if emitted:
# update decoder output
decoder_input = torch.tensor(
[stream.hyp[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
decoder_out = model.decoder(
decoder_input,
need_pad=False,
)
decoder_out = model.joiner.decoder_proj(decoder_out)
def modified_beam_search(
model: nn.Module,
encoder_out: torch.Tensor,
streams: List[Stream],
beam: int = 4,
):
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
Args:
model:
The RNN-T model.
encoder_out:
A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of
the encoder model.
streams:
A list of stream objects.
beam:
Number of active paths during the beam search.
"""
assert encoder_out.ndim == 3, encoder_out.shape
assert len(streams) == encoder_out.size(0)
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = next(model.parameters()).device
batch_size = len(streams)
T = encoder_out.size(1)
B = [stream.hyps for stream in streams]
encoder_out = model.joiner.encoder_proj(encoder_out)
for t in range(T):
current_encoder_out = encoder_out[:, t].unsqueeze(1).unsqueeze(1)
# current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
hyps_shape = get_hyps_shape(B).to(device)
A = [list(b) for b in B]
B = [HypothesisList() for _ in range(batch_size)]
ys_log_probs = torch.stack(
[hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0
) # (num_hyps, 1)
decoder_input = torch.tensor(
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
device=device,
dtype=torch.int64,
) # (num_hyps, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
decoder_out = model.joiner.decoder_proj(decoder_out)
# decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim)
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
# as index, so we use `to(torch.int64)` below.
current_encoder_out = torch.index_select(
current_encoder_out,
dim=0,
index=hyps_shape.row_ids(1).to(torch.int64),
) # (num_hyps, encoder_out_dim)
logits = model.joiner(
current_encoder_out, decoder_out, project_input=False
)
# logits is of shape (num_hyps, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1)
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size)
log_probs.add_(ys_log_probs)
vocab_size = log_probs.size(-1)
log_probs = log_probs.reshape(-1)
row_splits = hyps_shape.row_splits(1) * vocab_size
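# Each hypothesis expands to vocab_size flattened scores, so scaling the row
# splits by vocab_size keeps the per-utterance grouping after the reshape.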
log_probs_shape = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=log_probs.numel()
)
ragged_log_probs = k2.RaggedTensor(
shape=log_probs_shape, value=log_probs
)
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]
new_ys = hyp.ys[:]
new_token = topk_token_indexes[k]
if new_token != blank_id:
new_ys.append(new_token)
new_log_prob = topk_log_probs[k]
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
B[i].add(new_hyp)
for i in range(batch_size):
streams[i].hyps = B[i]
def fast_beam_search_one_best(
model: nn.Module,
streams: List[Stream],
encoder_out: torch.Tensor,
processed_lens: torch.Tensor,
beam: float,
max_states: int,
max_contexts: int,
) -> None:
"""It limits the maximum number of symbols per frame to 1.
A lattice is first obtained using modified beam search, and then
the shortest path within the lattice is used as the final output.
Args:
model:
An instance of `Transducer`.
streams:
A list of stream objects.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
processed_lens:
A tensor of shape (N,) containing the number of processed frames
in `encoder_out` before padding.
beam:
Beam value, similar to the beam used in Kaldi.
max_states:
Max states per stream per frame.
max_contexts:
Max contexts per stream per frame.
"""
assert encoder_out.ndim == 3
context_size = model.decoder.context_size
vocab_size = model.decoder.vocab_size
B, T, C = encoder_out.shape
assert B == len(streams)
config = k2.RnntDecodingConfig(
vocab_size=vocab_size,
decoder_history_len=context_size,
beam=beam,
max_contexts=max_contexts,
max_states=max_states,
)
individual_streams = []
for i in range(B):
individual_streams.append(streams[i].rnnt_decoding_stream)
decoding_streams = k2.RnntDecodingStreams(individual_streams, config)
encoder_out = model.joiner.encoder_proj(encoder_out)
for t in range(T):
# shape is a RaggedShape of shape (B, context)
# contexts is a Tensor of shape (shape.NumElements(), context_size)
shape, contexts = decoding_streams.get_contexts()
# `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
contexts = contexts.to(torch.int64)
# decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
decoder_out = model.decoder(contexts, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
# current_encoder_out is of shape
# (shape.NumElements(), 1, joiner_dim)
# fmt: off
current_encoder_out = torch.index_select(
encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64)
)
# fmt: on
logits = model.joiner(
current_encoder_out.unsqueeze(2),
decoder_out.unsqueeze(1),
project_input=False,
)
logits = logits.squeeze(1).squeeze(1)
log_probs = logits.log_softmax(dim=-1)
decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(processed_lens.tolist())
best_path = one_best_decoding(lattice)
hyps = get_texts(best_path)
for i in range(B):
streams[i].hyp = hyps[i]
def decode_one_chunk(
model: nn.Module,
streams: List[Stream],
params: AttributeDict,
decoding_graph: Optional[k2.Fsa] = None,
) -> List[int]:
"""
Args:
model:
The Transducer model.
streams:
A list of Stream objects.
params:
It is returned by :func:`get_params`.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or an LG graph. Used
only when --decoding_method is fast_beam_search.
Returns:
A list of indexes indicating the finished streams.
"""
device = next(model.parameters()).device
feature_list = []
feature_len_list = []
state_list = []
num_processed_frames_list = []
for stream in streams:
# We should first get `stream.num_processed_frames`
# before calling `stream.get_feature_chunk()`
# since `stream.num_processed_frames` would be updated
num_processed_frames_list.append(stream.num_processed_frames)
feature = stream.get_feature_chunk()
feature_len = feature.size(0)
feature_list.append(feature)
feature_len_list.append(feature_len)
state_list.append(stream.states)
features = pad_sequence(
feature_list, batch_first=True, padding_value=LOG_EPSILON
).to(device)
feature_lens = torch.tensor(feature_len_list, device=device)
num_processed_frames = torch.tensor(
num_processed_frames_list, device=device
)
# Make sure it has at least 1 frame after subsampling
tail_length = params.subsampling_factor + 5
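# With the default subsampling_factor of 4 this is 9 frames, the smallest
# input length for which ((T - 3) // 2 - 1) // 2 >= 1.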
if features.size(1) < tail_length:
pad_length = tail_length - features.size(1)
feature_lens += pad_length
features = torch.nn.functional.pad(
features,
(0, 0, 0, pad_length),
mode="constant",
value=LOG_EPSILON,
)
# Stack states of all streams
states = stack_states(state_list)
encoder_out, encoder_out_lens, states = model.encoder(
x=features,
x_lens=feature_lens,
states=states,
)
if params.decoding_method == "greedy_search":
greedy_search(
model=model,
streams=streams,
encoder_out=encoder_out,
)
elif params.decoding_method == "modified_beam_search":
modified_beam_search(
model=model,
streams=streams,
encoder_out=encoder_out,
beam=params.beam_size,
)
elif params.decoding_method == "fast_beam_search":
# processed_lens (the number of frames processed so far, after
# subsampling) is needed by the rnnt_decoding_stream of
# fast_beam_search to produce partial results.
with warnings.catch_warnings():
warnings.simplefilter("ignore")
processed_lens = (
num_processed_frames // params.subsampling_factor
+ encoder_out_lens
)
fast_beam_search_one_best(
model=model,
streams=streams,
encoder_out=encoder_out,
processed_lens=processed_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
# Update cached states of each stream
state_list = unstack_states(states)
for i, s in enumerate(state_list):
streams[i].states = s
finished_streams = [i for i, stream in enumerate(streams) if stream.done]
return finished_streams
def create_streaming_feature_extractor() -> Fbank:
"""Create a CPU streaming feature extractor.
At present, it returns an fbank feature extractor with fixed
options. In the future, we will support passing in the options
from outside.
Returns:
Return a CPU streaming feature extractor.
"""
opts = FbankOptions()
opts.device = "cpu"
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
return Fbank(opts)
def decode_dataset(
cuts: CutSet,
model: nn.Module,
params: AttributeDict,
sp: spm.SentencePieceProcessor,
decoding_graph: Optional[k2.Fsa] = None,
):
"""Decode dataset.
Args:
cuts:
Lhotse CutSet containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The Transducer model.
sp:
The BPE model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or an LG graph.
Used only when --decoding_method is fast_beam_search.
Returns:
Return a dict whose key is the decoding method, e.g. "greedy_search"
if greedy search is used, or "beam_7" if a beam size of 7 is used.
Its value is a list of tuples. Each tuple contains three elements:
the cut id, the reference transcript, and the predicted result.
"""
device = next(model.parameters()).device
log_interval = 300
fbank = create_streaming_feature_extractor()
decode_results = []
streams = []
for num, cut in enumerate(cuts):
# Each utterance has a Stream.
stream = Stream(
params=params,
cut_id=cut.id,
decoding_graph=decoding_graph,
device=device,
LOG_EPS=LOG_EPSILON,
)
stream.states = model.encoder.get_init_states(device=device)
audio: np.ndarray = cut.load_audio()
# audio.shape: (1, num_samples)
assert len(audio.shape) == 2
assert audio.shape[0] == 1, "Should be single channel"
assert audio.dtype == np.float32, audio.dtype
# The trained model is using normalized samples
assert audio.max() <= 1, "Should be normalized to [-1, 1]"
samples = torch.from_numpy(audio).squeeze(0)
feature = fbank(samples)
stream.set_feature(feature)
stream.ground_truth = cut.supervisions[0].text
streams.append(stream)
while len(streams) >= params.num_decode_streams:
finished_streams = decode_one_chunk(
model=model,
streams=streams,
params=params,
decoding_graph=decoding_graph,
)
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
streams[i].id,
streams[i].ground_truth.split(),
sp.decode(streams[i].decoding_result()).split(),
)
)
del streams[i]
if num % log_interval == 0:
logging.info(f"Cuts processed until now is {num}.")
while len(streams) > 0:
finished_streams = decode_one_chunk(
model=model,
streams=streams,
params=params,
decoding_graph=decoding_graph,
)
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
streams[i].id,
streams[i].ground_truth.split(),
sp.decode(streams[i].decoding_result()).split(),
)
)
del streams[i]
if params.decoding_method == "greedy_search":
key = "greedy_search"
elif params.decoding_method == "fast_beam_search":
key = (
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
)
else:
key = f"beam_size_{params.beam_size}"
return {key: decode_results}
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
store_transcripts(filename=recog_path, texts=sorted(results))
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"fast_beam_search",
"modified_beam_search",
)
params.res_dir = params.exp_dir / "streaming" / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-streaming-decode")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
params.device = device
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.eval()
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_sets = ["test-clean", "test-other"]
test_cuts = [test_clean_cuts, test_other_cuts]
for test_set, test_cut in zip(test_sets, test_cuts):
results_dict = decode_dataset(
cuts=test_cut,
model=model,
params=params,
sp=sp,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
torch.manual_seed(20220810)
main()

View File

@ -0,0 +1,92 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./lstm_transducer_stateless/test_model.py
"""
import os
from pathlib import Path
import torch
from export import (
export_decoder_model_jit_trace,
export_encoder_model_jit_trace,
export_joiner_model_jit_trace,
)
from lstm import stack_states, unstack_states
from scaling_converter import convert_scaled_to_non_scaled
from train import get_params, get_transducer_model
def test_model():
params = get_params()
params.vocab_size = 500
params.blank_id = 0
params.context_size = 2
params.unk_id = 2
params.encoder_dim = 512
params.rnn_hidden_size = 1024
params.num_encoder_layers = 12
params.aux_layer_period = 0
params.exp_dir = Path("exp_test_model")
model = get_transducer_model(params)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
convert_scaled_to_non_scaled(model, inplace=True)
if not os.path.exists(params.exp_dir):
os.mkdir(params.exp_dir)
encoder_filename = params.exp_dir / "encoder_jit_trace.pt"
export_encoder_model_jit_trace(model.encoder, encoder_filename)
decoder_filename = params.exp_dir / "decoder_jit_trace.pt"
export_decoder_model_jit_trace(model.decoder, decoder_filename)
joiner_filename = params.exp_dir / "joiner_jit_trace.pt"
export_joiner_model_jit_trace(model.joiner, joiner_filename)
print("The model has been successfully exported using jit.trace.")
def test_states_stack_and_unstack():
layer, batch, hidden, cell = 12, 100, 512, 1024
states = (
torch.randn(layer, batch, hidden),
torch.randn(layer, batch, cell),
)
states2 = stack_states(unstack_states(states))
assert torch.allclose(states[0], states2[0])
assert torch.allclose(states[1], states2[1])
def main():
test_model()
test_states_stack_and_unstack()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,257 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./lstm_transducer_stateless/test_scaling_converter.py
"""
import copy
import torch
from scaling import (
ScaledConv1d,
ScaledConv2d,
ScaledEmbedding,
ScaledLinear,
ScaledLSTM,
)
from scaling_converter import (
convert_scaled_to_non_scaled,
scaled_conv1d_to_conv1d,
scaled_conv2d_to_conv2d,
scaled_embedding_to_embedding,
scaled_linear_to_linear,
scaled_lstm_to_lstm,
)
from train import get_params, get_transducer_model
def get_model():
params = get_params()
params.vocab_size = 500
params.blank_id = 0
params.context_size = 2
params.unk_id = 2
params.encoder_dim = 512
params.rnn_hidden_size = 1024
params.num_encoder_layers = 12
params.aux_layer_period = -1
model = get_transducer_model(params)
return model
def test_scaled_linear_to_linear():
N = 5
in_features = 10
out_features = 20
for bias in [True, False]:
scaled_linear = ScaledLinear(
in_features=in_features,
out_features=out_features,
bias=bias,
)
linear = scaled_linear_to_linear(scaled_linear)
x = torch.rand(N, in_features)
y1 = scaled_linear(x)
y2 = linear(x)
assert torch.allclose(y1, y2)
jit_scaled_linear = torch.jit.script(scaled_linear)
jit_linear = torch.jit.script(linear)
y3 = jit_scaled_linear(x)
y4 = jit_linear(x)
assert torch.allclose(y3, y4)
assert torch.allclose(y1, y4)
def test_scaled_conv1d_to_conv1d():
in_channels = 3
for bias in [True, False]:
scaled_conv1d = ScaledConv1d(
in_channels,
6,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
conv1d = scaled_conv1d_to_conv1d(scaled_conv1d)
x = torch.rand(20, in_channels, 10)
y1 = scaled_conv1d(x)
y2 = conv1d(x)
assert torch.allclose(y1, y2)
jit_scaled_conv1d = torch.jit.script(scaled_conv1d)
jit_conv1d = torch.jit.script(conv1d)
y3 = jit_scaled_conv1d(x)
y4 = jit_conv1d(x)
assert torch.allclose(y3, y4)
assert torch.allclose(y1, y4)
def test_scaled_conv2d_to_conv2d():
in_channels = 1
for bias in [True, False]:
scaled_conv2d = ScaledConv2d(
in_channels=in_channels,
out_channels=3,
kernel_size=3,
padding=1,
bias=bias,
)
conv2d = scaled_conv2d_to_conv2d(scaled_conv2d)
x = torch.rand(20, in_channels, 10, 20)
y1 = scaled_conv2d(x)
y2 = conv2d(x)
assert torch.allclose(y1, y2)
jit_scaled_conv2d = torch.jit.script(scaled_conv2d)
jit_conv2d = torch.jit.script(conv2d)
y3 = jit_scaled_conv2d(x)
y4 = jit_conv2d(x)
assert torch.allclose(y3, y4)
assert torch.allclose(y1, y4)
def test_scaled_embedding_to_embedding():
scaled_embedding = ScaledEmbedding(
num_embeddings=500,
embedding_dim=10,
padding_idx=0,
)
embedding = scaled_embedding_to_embedding(scaled_embedding)
for s in [10, 100, 300, 500, 800, 1000]:
x = torch.randint(low=0, high=500, size=(s,))
scaled_y = scaled_embedding(x)
y = embedding(x)
assert torch.equal(scaled_y, y)
def test_scaled_lstm_to_lstm():
input_size = 512
batch_size = 20
for bias in [True, False]:
for hidden_size in [512, 1024]:
scaled_lstm = ScaledLSTM(
input_size=input_size,
hidden_size=hidden_size,
num_layers=1,
bias=bias,
proj_size=0 if hidden_size == input_size else input_size,
)
lstm = scaled_lstm_to_lstm(scaled_lstm)
x = torch.rand(200, batch_size, input_size)
h0 = torch.randn(1, batch_size, input_size)
c0 = torch.randn(1, batch_size, hidden_size)
y1, (h1, c1) = scaled_lstm(x, (h0, c0))
y2, (h2, c2) = lstm(x, (h0, c0))
assert torch.allclose(y1, y2)
assert torch.allclose(h1, h2)
assert torch.allclose(c1, c2)
jit_scaled_lstm = torch.jit.trace(lstm, (x, (h0, c0)))
y3, (h3, c3) = jit_scaled_lstm(x, (h0, c0))
assert torch.allclose(y1, y3)
assert torch.allclose(h1, h3)
assert torch.allclose(c1, c3)
def test_convert_scaled_to_non_scaled():
for inplace in [False, True]:
model = get_model()
model.eval()
orig_model = copy.deepcopy(model)
converted_model = convert_scaled_to_non_scaled(model, inplace=inplace)
model = orig_model
# test encoder
N = 2
T = 100
vocab_size = model.decoder.vocab_size
x = torch.randn(N, T, 80, dtype=torch.float32)
x_lens = torch.full((N,), x.size(1))
e1, e1_lens, _ = model.encoder(x, x_lens)
e2, e2_lens, _ = converted_model.encoder(x, x_lens)
assert torch.all(torch.eq(e1_lens, e2_lens))
assert torch.allclose(e1, e2), (e1 - e2).abs().max()
# test decoder
U = 50
y = torch.randint(low=1, high=vocab_size - 1, size=(N, U))
d1 = model.decoder(y)
d2 = converted_model.decoder(y)
assert torch.allclose(d1, d2)
# test simple projection
lm1 = model.simple_lm_proj(d1)
am1 = model.simple_am_proj(e1)
lm2 = converted_model.simple_lm_proj(d2)
am2 = converted_model.simple_am_proj(e2)
assert torch.allclose(lm1, lm2)
assert torch.allclose(am1, am2)
# test joiner
e = torch.rand(2, 3, 4, 512)
d = torch.rand(2, 3, 4, 512)
j1 = model.joiner(e, d)
j2 = converted_model.joiner(e, d)
assert torch.allclose(j1, j2)
@torch.no_grad()
def main():
test_scaled_linear_to_linear()
test_scaled_conv1d_to_conv1d()
test_scaled_conv2d_to_conv2d()
test_scaled_embedding_to_embedding()
test_scaled_lstm_to_lstm()
test_convert_scaled_to_non_scaled()
if __name__ == "__main__":
torch.manual_seed(20220730)
main()

File diff suppressed because it is too large

View File

@ -527,7 +527,9 @@ class ConformerEncoderLayer(nn.Module):
src = src + self.dropout(src_att)
# convolution module
conv, _ = self.conv_module(src)
conv, _ = self.conv_module(
src, src_key_padding_mask=src_key_padding_mask
)
src = src + self.dropout(conv)
# feed forward module
@ -1457,6 +1459,7 @@ class ConvolutionModule(nn.Module):
x: Tensor,
cache: Optional[Tensor] = None,
right_context: int = 0,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
"""Compute convolution module.
@ -1467,6 +1470,7 @@ class ConvolutionModule(nn.Module):
right_context:
How many future frames the attention can see in current chunk.
Note: It's not that each individual frame has `right_context` frames
of right context, some have more.
src_key_padding_mask: the mask for the src keys per batch (optional).
Returns:
@ -1486,6 +1490,8 @@ class ConvolutionModule(nn.Module):
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
# 1D Depthwise Conv
if src_key_padding_mask is not None:
x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
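# Zeroing the padded positions here keeps padding frames from leaking
# into neighbouring valid frames through the receptive field of the
# depthwise convolution below.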
if self.causal and self.lorder > 0:
if cache is None:
# Make depthwise_conv causal by

View File

@ -164,6 +164,10 @@ class Eve(Optimizer):
p.mul_(1 - (weight_decay * is_above_target_rms))
p.addcdiv_(exp_avg, denom, value=-step_size)
# Constrain the range of scalar weights
if p.numel() == 1:
p.clamp_(min=-10, max=2)
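# (For these models, the only single-element parameters are presumably
# the log-scale factors of the Scaled* modules, so this keeps each
# scale within roughly [exp(-10), exp(2)].)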
return loss

View File

@ -111,6 +111,76 @@ class ActivationBalancerFunction(torch.autograd.Function):
return x_grad - neg_delta_grad, None, None, None, None, None, None
class GradientFilterFunction(torch.autograd.Function):
@staticmethod
def forward(
ctx,
x: Tensor,
batch_dim: int, # e.g., 1
threshold: float, # e.g., 10.0
*params: Tensor, # module parameters
) -> Tuple[Tensor, ...]:
if x.requires_grad:
if batch_dim < 0:
batch_dim += x.ndim
ctx.batch_dim = batch_dim
ctx.threshold = threshold
return (x,) + params
@staticmethod
def backward(
ctx,
x_grad: Tensor,
*param_grads: Tensor,
) -> Tuple[Tensor, ...]:
eps = 1.0e-20
dim = ctx.batch_dim
norm_dims = [d for d in range(x_grad.ndim) if d != dim]
norm_of_batch = (x_grad ** 2).mean(dim=norm_dims, keepdim=True).sqrt()
median_norm = norm_of_batch.median()
cutoff = median_norm * ctx.threshold
inv_mask = (cutoff + norm_of_batch) / (cutoff + eps)
mask = 1.0 / (inv_mask + eps)
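# Effectively mask ~= cutoff / (cutoff + norm_of_batch): elements whose
# gradient norm is well below the cutoff keep a mask close to 1, while
# an element with norm k * cutoff is scaled down by roughly 1 / (1 + k).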
x_grad = x_grad * mask
avg_mask = 1.0 / (inv_mask.mean() + eps)
param_grads = [avg_mask * g for g in param_grads]
return (x_grad, None, None) + tuple(param_grads)
class GradientFilter(torch.nn.Module):
"""This is used to filter out elements that have extremely large gradients
in batch and the module parameters with soft masks.
Args:
batch_dim (int):
The batch dimension.
threshold (float):
For each element in the batch, its gradient will be
scaled down if its gradient norm is larger than
`threshold * median`, where `median` is the median
value of the gradient norms of all elements in the batch.
"""
def __init__(self, batch_dim: int = 1, threshold: float = 10.0):
super(GradientFilter, self).__init__()
self.batch_dim = batch_dim
self.threshold = threshold
def forward(self, x: Tensor, *params: Tensor) -> Tuple[Tensor, ...]:
if torch.jit.is_scripting() or is_jit_tracing():
return (x,) + params
else:
return GradientFilterFunction.apply(
x,
self.batch_dim,
self.threshold,
*params,
)
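# A minimal usage sketch (mirroring how ScaledLSTM wires it in further
# below; the variable names here are illustrative only):
#
#   grad_filter = GradientFilter(batch_dim=1, threshold=10.0)
#   x, *flat_weights = grad_filter(x, *flat_weights)
#
# In the backward pass, the gradient of x is soft-masked per batch
# element and the weight gradients are rescaled by the average mask.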
class BasicNorm(torch.nn.Module):
"""
This is intended to be a simpler, and hopefully cheaper, replacement for
@ -195,7 +265,7 @@ class ScaledLinear(nn.Linear):
*args,
initial_scale: float = 1.0,
initial_speed: float = 1.0,
**kwargs
**kwargs,
):
super(ScaledLinear, self).__init__(*args, **kwargs)
initial_scale = torch.tensor(initial_scale).log()
@ -242,7 +312,7 @@ class ScaledConv1d(nn.Conv1d):
*args,
initial_scale: float = 1.0,
initial_speed: float = 1.0,
**kwargs
**kwargs,
):
super(ScaledConv1d, self).__init__(*args, **kwargs)
initial_scale = torch.tensor(initial_scale).log()
@ -314,7 +384,7 @@ class ScaledConv2d(nn.Conv2d):
*args,
initial_scale: float = 1.0,
initial_speed: float = 1.0,
**kwargs
**kwargs,
):
super(ScaledConv2d, self).__init__(*args, **kwargs)
initial_scale = torch.tensor(initial_scale).log()
@ -389,7 +459,8 @@ class ScaledLSTM(nn.LSTM):
*args,
initial_scale: float = 1.0,
initial_speed: float = 1.0,
**kwargs
grad_norm_threshold: float = 10.0,
**kwargs,
):
if "bidirectional" in kwargs:
assert kwargs["bidirectional"] is False
@ -404,6 +475,10 @@ class ScaledLSTM(nn.LSTM):
setattr(self, scale_name, param)
self._scales.append(param)
self.grad_filter = GradientFilter(
batch_dim=1, threshold=grad_norm_threshold
)
self._reset_parameters(
initial_speed
) # Overrides the reset_parameters in base class
@ -513,10 +588,14 @@ class ScaledLSTM(nn.LSTM):
hx = (h_zeros, c_zeros)
self.check_forward_args(input, hx, None)
flat_weights = self._get_flat_weights()
input, *flat_weights = self.grad_filter(input, *flat_weights)
result = _VF.lstm(
input,
hx,
self._get_flat_weights(),
flat_weights,
self.bias,
self.num_layers,
self.dropout,
@ -891,9 +970,54 @@ def _test_scaled_lstm():
assert c.shape == (1, N, dim_hidden)
def _test_grad_filter():
threshold = 50.0
time, batch, channel = 200, 5, 128
grad_filter = GradientFilter(batch_dim=1, threshold=threshold)
for i in range(2):
x = torch.randn(time, batch, channel, requires_grad=True)
w = nn.Parameter(torch.ones(5))
b = nn.Parameter(torch.zeros(5))
x_out, w_out, b_out = grad_filter(x, w, b)
w_out_grad = torch.randn_like(w)
b_out_grad = torch.randn_like(b)
x_out_grad = torch.rand_like(x)
if i % 2 == 1:
# The gradient norm of the first element must be larger than
# `threshold * median`, where `median` is the median value
# of gradient norms of all elements in batch.
x_out_grad[:, 0, :] = torch.full((time, channel), threshold)
torch.autograd.backward(
[x_out, w_out, b_out], [x_out_grad, w_out_grad, b_out_grad]
)
print(
"_test_grad_filter: for gradient norms, the first element > median * threshold ", # noqa
i % 2 == 1,
)
print(
"_test_grad_filter: x_out_grad norm = ",
(x_out_grad ** 2).mean(dim=(0, 2)).sqrt(),
)
print(
"_test_grad_filter: x.grad norm = ",
(x.grad ** 2).mean(dim=(0, 2)).sqrt(),
)
print("_test_grad_filter: w_out_grad = ", w_out_grad)
print("_test_grad_filter: w.grad = ", w.grad)
print("_test_grad_filter: b_out_grad = ", b_out_grad)
print("_test_grad_filter: b.grad = ", b.grad)
if __name__ == "__main__":
_test_activation_balancer_sign()
_test_activation_balancer_magnitude()
_test_basic_norm()
_test_double_swish_deriv()
_test_scaled_lstm()
_test_grad_filter()

View File

@ -30,6 +30,7 @@ from typing import List
import torch
import torch.nn as nn
from scaling import (
BasicNorm,
ScaledConv1d,
ScaledConv2d,
ScaledEmbedding,
@ -38,6 +39,29 @@ from scaling import (
)
class NonScaledNorm(nn.Module):
"""See BasicNorm for doc"""
def __init__(
self,
num_channels: int,
eps_exp: float,
channel_dim: int = -1, # CAUTION: see documentation.
):
super().__init__()
self.num_channels = num_channels
self.channel_dim = channel_dim
self.eps_exp = eps_exp
def forward(self, x: torch.Tensor) -> torch.Tensor:
if not torch.jit.is_tracing():
assert x.shape[self.channel_dim] == self.num_channels
scales = (
torch.mean(x * x, dim=self.channel_dim, keepdim=True) + self.eps_exp
).pow(-0.5)
return x * scales
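# NonScaledNorm reproduces BasicNorm's output with the learned `eps`
# baked in as the constant `eps_exp` (convert_basic_norm below passes
# basic_norm.eps.data.exp().item()), which presumably makes the
# converted model easier to export via torch.jit.trace / PNNX.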
def scaled_linear_to_linear(scaled_linear: ScaledLinear) -> nn.Linear:
"""Convert an instance of ScaledLinear to nn.Linear.
@ -60,7 +84,7 @@ def scaled_linear_to_linear(scaled_linear: ScaledLinear) -> nn.Linear:
linear = torch.nn.Linear(
in_features=scaled_linear.in_features,
out_features=scaled_linear.out_features,
bias=True, # otherwise, it throws errors when converting to PNNX format.
bias=True, # otherwise, it throws errors when converting to PNNX format
# device=weight.device, # Pytorch version before v1.9.0 does not has
# this argument. Comment out for now, we will
# see if it will raise error for versions
@ -174,6 +198,16 @@ def scaled_embedding_to_embedding(
return embedding
def convert_basic_norm(basic_norm: BasicNorm) -> NonScaledNorm:
assert isinstance(basic_norm, BasicNorm), type(basic_norm)
norm = NonScaledNorm(
num_channels=basic_norm.num_channels,
eps_exp=basic_norm.eps.data.exp().item(),
channel_dim=basic_norm.channel_dim,
)
return norm
def scaled_lstm_to_lstm(scaled_lstm: ScaledLSTM) -> nn.LSTM:
"""Convert an instance of ScaledLSTM to nn.LSTM.
@ -206,7 +240,7 @@ def scaled_lstm_to_lstm(scaled_lstm: ScaledLSTM) -> nn.LSTM:
return lstm
# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule
# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa
# get_submodule was added to nn.Module at v1.9.0
def get_submodule(model, target):
if target == "":
@ -256,6 +290,8 @@ def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False):
d[name] = scaled_conv2d_to_conv2d(m)
elif isinstance(m, ScaledEmbedding):
d[name] = scaled_embedding_to_embedding(m)
elif isinstance(m, BasicNorm):
d[name] = convert_basic_norm(m)
elif isinstance(m, ScaledLSTM):
d[name] = scaled_lstm_to_lstm(m)

View File

@ -181,7 +181,7 @@ def test_convert_scaled_to_non_scaled():
y = torch.randint(low=1, high=vocab_size - 1, size=(N, U))
d1 = model.decoder(y)
d2 = model.decoder(y)
d2 = converted_model.decoder(y)
assert torch.allclose(d1, d2)

View File

@ -39,7 +39,7 @@ cd egs/librispeech/ASR/
--world-size 4 \
--num-epochs 30 \
--start-epoch 0 \
--use_fp16 1 \
--use-fp16 1 \
--exp-dir pruned_transducer_stateless3/exp \
--full-libri 1 \
--max-duration 550

View File

@ -527,7 +527,9 @@ class ConformerEncoderLayer(nn.Module):
src = src + self.dropout(src_att)
# convolution module
conv, _ = self.conv_module(src)
conv, _ = self.conv_module(
src, src_key_padding_mask=src_key_padding_mask
)
src = src + self.dropout(conv)
# feed forward module
@ -1436,7 +1438,11 @@ class ConvolutionModule(nn.Module):
)
def forward(
self, x: Tensor, cache: Optional[Tensor] = None, right_context: int = 0
self,
x: Tensor,
cache: Optional[Tensor] = None,
right_context: int = 0,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
"""Compute convolution module.
@ -1448,6 +1454,7 @@ class ConvolutionModule(nn.Module):
How many future frames the attention can see in current chunk.
Note: It's not that each individual frame has `right_context` frames
of right context, some have more.
src_key_padding_mask: the mask for the src keys per batch (optional).
Returns:
Tensor: Output tensor (#time, batch, channels).
@ -1466,6 +1473,8 @@ class ConvolutionModule(nn.Module):
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
# 1D Depthwise Conv
if src_key_padding_mask is not None:
x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
if self.causal and self.lorder > 0:
if cache is None:
# Make depthwise_conv causal by

View File

@ -264,7 +264,9 @@ class ConformerEncoderLayer(nn.Module):
src = src + self.dropout(src_att)
# convolution module
src = src + self.dropout(self.conv_module(src))
src = src + self.dropout(
self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
)
# feed forward module
src = src + self.dropout(self.feed_forward(src))
@ -927,11 +929,16 @@ class ConvolutionModule(nn.Module):
initial_scale=0.25,
)
def forward(self, x: Tensor) -> Tensor:
def forward(
self,
x: Tensor,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
"""Compute convolution module.
Args:
x: Input tensor (#time, batch, channels).
src_key_padding_mask: the mask for the src keys per batch (optional).
Returns:
Tensor: Output tensor (#time, batch, channels).
@ -947,6 +954,8 @@ class ConvolutionModule(nn.Module):
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
# 1D Depthwise Conv
if src_key_padding_mask is not None:
x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
x = self.depthwise_conv(x)
x = self.deriv_balancer2(x)

View File

@ -81,18 +81,17 @@ def decode_dataset(
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
# hyps is a list, every element is decode result of a sentence.
hyps = hubert_model.ctc_greedy_search(batch)
texts = batch["supervisions"]["text"]
assert len(hyps) == len(texts)
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
this_batch = []
for hyp_text, ref_text in zip(hyps, texts):
assert len(hyps) == len(texts)
for cut_id, hyp_text, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
hyp_words = hyp_text.split()
this_batch.append((ref_words, hyp_words))
this_batch.append((cut_id, ref_words, hyp_words))
results["ctc_greedy_search"].extend(this_batch)
num_cuts += len(texts)

View File

@ -514,7 +514,9 @@ class ConformerEncoderLayer(nn.Module):
if self.normalize_before:
src = self.norm_conv(src)
src, _ = self.conv_module(src)
src, _ = self.conv_module(
src, src_key_padding_mask=src_key_padding_mask
)
src = residual + self.dropout(src)
if not self.normalize_before:
@ -1383,11 +1385,18 @@ class ConvolutionModule(nn.Module):
x: Tensor,
cache: Optional[Tensor] = None,
right_context: int = 0,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
"""Compute convolution module.
Args:
x: Input tensor (#time, batch, channels).
cache: The cache of depthwise_conv, only used in real streaming
decoding.
right_context:
How many future frames the attention can see in current chunk.
Note: It's not that each individual frame has `right_context` frames
src_key_padding_mask: the mask for the src keys per batch (optional).
Returns:
Tensor: Output tensor (#time, batch, channels).
@ -1401,6 +1410,8 @@ class ConvolutionModule(nn.Module):
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
# 1D Depthwise Conv
if src_key_padding_mask is not None:
x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
if self.causal and self.lorder > 0:
if cache is None:
# Make depthwise_conv causal by

View File

@ -69,6 +69,13 @@ def compute_fbank_musan():
)
assert manifests is not None
assert len(manifests) == len(dataset_parts), (
len(manifests),
len(dataset_parts),
list(manifests.keys()),
dataset_parts,
)
musan_cuts_path = src_dir / "cuts_musan.jsonl.gz"
if musan_cuts_path.is_file():

View File

@ -62,6 +62,13 @@ def compute_fbank_tal_csasr(num_mel_bins: int = 80):
)
assert manifests is not None
assert len(manifests) == len(dataset_parts), (
len(manifests),
len(dataset_parts),
list(manifests.keys()),
dataset_parts,
)
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once.

Some files were not shown because too many files have changed in this diff