apply black on all files

Desh Raj 2022-11-17 09:42:17 -05:00
parent b3920e5ab5
commit 107df3b115
437 changed files with 3861 additions and 7334 deletions

View File

@ -45,17 +45,18 @@ jobs:
- name: Install Python dependencies
run: |
python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2 click==8.0.4
# See https://github.com/psf/black/issues/2964
# The version of click should be selected from 8.0.0, 8.0.1, 8.0.2, 8.0.3, and 8.0.4
python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0
# Click issue fixed in https://github.com/psf/black/pull/2966
- name: Run flake8
shell: bash
working-directory: ${{github.workspace}}
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --show-source --statistics
flake8 .
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 \
--statistics --extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503
- name: Run black
shell: bash
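To reproduce these CI checks locally before pushing, here is a minimal sketch; the exact arguments of the `Run black` step are not shown in this hunk, so `black --check .` is an assumption:

```bash
# Install the same pinned tool versions as the workflow.
python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0

# Stop on Python syntax errors and undefined names, as the workflow does.
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics

# Report the remaining issues as warnings only (exit-zero), mirroring the workflow.
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 \
  --statistics --extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503

# Check formatting without rewriting files (assumed invocation of the "Run black" step).
black --check .
```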

View File

@ -1,26 +1,38 @@
repos:
- repo: https://github.com/psf/black
rev: 21.6b0
rev: 22.3.0
hooks:
- id: black
args: [--line-length=80]
additional_dependencies: ['click==8.0.1']
args: ["--line-length=88"]
additional_dependencies: ['click==8.1.0']
exclude: icefall\/__init__\.py
- repo: https://github.com/PyCQA/flake8
rev: 3.9.2
rev: 5.0.4
hooks:
- id: flake8
args: [--max-line-length=80]
args: ["--max-line-length=88", "--extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503"]
# What are we ignoring here?
# E203: whitespace before ':'
# E266: too many leading '#' for block comment
# E501: line too long
# F401: module imported but unused
# E402: module level import not at top of file
# F403: 'from module import *' used; unable to detect undefined names
# F841: local variable is assigned to but never used
# W503: line break before binary operator
# In addition, the default ignore list is:
# E121,E123,E126,E226,E24,E704,W503,W504
- repo: https://github.com/pycqa/isort
rev: 5.9.2
rev: 5.10.1
hooks:
- id: isort
args: [--profile=black, --line-length=80]
args: ["--profile=black"]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.0.1
rev: v4.2.0
hooks:
- id: check-executables-have-shebangs
- id: end-of-file-fixer
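With this configuration in place, the same hooks can be run locally through pre-commit; a short sketch, assuming pre-commit itself still needs to be installed in the environment:

```bash
# Install pre-commit and register the git hook defined by .pre-commit-config.yaml.
pip install pre-commit
pre-commit install

# Run black, flake8, isort and the misc hooks once over the whole repository.
pre-commit run --all-files
```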

View File

@ -2,7 +2,7 @@
2 sets of configuration are provided - (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8, and (b) Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8.
If your NVIDIA driver supports CUDA Version: 11.3, please go for case (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8.
Otherwise, since the older PyTorch images are not updated with the [apt-key rotation by NVIDIA](https://developer.nvidia.com/blog/updating-the-cuda-linux-gpg-repository-key), you have to go for case (b) Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8. Ensure that your NVIDIA driver supports at least CUDA 11.0.
@ -10,7 +10,7 @@ You can check the highest CUDA version within your NVIDIA driver's support with
```bash
$ nvidia-smi
Tue Sep 20 00:26:13 2022
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.119.03 Driver Version: 450.119.03 CUDA Version: 11.0 |
|-------------------------------+----------------------+----------------------+
@ -26,7 +26,7 @@ Tue Sep 20 00:26:13 2022
| 41% 30C P8 11W / 280W | 6MiB / 24220MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
@ -40,15 +40,15 @@ Tue Sep 20 00:26:13 2022
```
## Building images locally
If your environment requires a proxy to access the Internet, remember to add that information to the Dockerfile directly.
In most cases, you can uncomment these lines in the Dockerfile and add in your proxy details.
```dockerfile
ENV http_proxy=http://aaa.bb.cc.net:8080 \
https_proxy=http://aaa.bb.cc.net:8080
```
Then, proceed with these commands.
### If you are case (a), i.e. your NVIDIA driver supports CUDA version >= 11.3:
@ -72,11 +72,11 @@ docker run -it --runtime=nvidia --shm-size=2gb --name=icefall --gpus all icefall
```
### Tips:
1. Since your data and models most probably won't be in the docker, you must use the -v flag to access the host machine. Do this by specifying `-v {/path/in/host/machine}:{/path/in/docker}`.
2. Also, if your environment requires a proxy, this would be a good time to add it in too: `-e http_proxy=http://aaa.bb.cc.net:8080 -e https_proxy=http://aaa.bb.cc.net:8080`.
Overall, your docker run command should look like this.
```bash
docker run -it --runtime=nvidia --shm-size=2gb --name=icefall --gpus all -v {/path/in/host/machine}:{/path/in/docker} -e http_proxy=http://aaa.bb.cc.net:8080 -e https_proxy=http://aaa.bb.cc.net:8080 icefall/pytorch1.12.1
@ -86,9 +86,9 @@ You can explore more docker run options [here](https://docs.docker.com/engine/re
### Linking to icefall in your host machine
If you already have icefall downloaded onto your host machine, you can use that repository instead so that changes in your code are visible inside and outside of the container.
Note: Remember to set the -v flag above during the first run of the container, as that is the only way for your container to access your host machine.
Warning: Check that the icefall in your host machine is visible from within your container before proceeding to the commands below.
Use these commands once you are inside the container.
@ -103,7 +103,7 @@ ln -s {/path/in/docker/to/icefall} /workspace/icefall
docker exec -it icefall /bin/bash
```
## Restarting a killed container that has been run before.
```bash
docker start -ai icefall
```
@ -111,4 +111,4 @@ docker start -ai icefall
## Sample usage of the CPU based images:
```bash
docker run -it icefall /bin/bash
```

View File

@ -1,7 +1,7 @@
FROM pytorch/pytorch:1.12.1-cuda11.3-cudnn8-devel
# ENV http_proxy=http://aaa.bbb.cc.net:8080 \
# https_proxy=http://aaa.bbb.cc.net:8080
# install normal source
RUN apt-get update && \
@ -38,10 +38,10 @@ RUN wget -P /opt https://cmake.org/files/v3.18/cmake-3.18.0.tar.gz && \
rm -rf cmake-3.18.0.tar.gz && \
find /opt/cmake-3.18.0 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \
cd -
# flac
RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz && \
cd /opt && \
xz -d flac-1.3.2.tar.xz && \
tar -xvf flac-1.3.2.tar && \
cd flac-1.3.2 && \
@ -49,11 +49,11 @@ RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz &&
make && make install && \
rm -rf flac-1.3.2.tar && \
find /opt/flac-1.3.2 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \
cd -
RUN conda install -y -c pytorch torchaudio=0.12 && \
pip install graphviz
#install k2 from source
RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \
@ -68,7 +68,7 @@ RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
cd /workspace/icefall && \
pip install -r requirements.txt
RUN pip install kaldifeat
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
WORKDIR /workspace/icefall
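A minimal sketch of building and starting this image locally; the Dockerfile directory is an assumption, while the `icefall/pytorch1.12.1` tag matches the `docker run` examples in the README above:

```bash
# Build the CUDA 11.3 image (the Dockerfile directory name is assumed).
docker build -t icefall/pytorch1.12.1 docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8

# Start a GPU-enabled container, as shown in the README above.
docker run -it --runtime=nvidia --shm-size=2gb --name=icefall --gpus all icefall/pytorch1.12.1
```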

View File

@ -1,12 +1,12 @@
FROM pytorch/pytorch:1.7.1-cuda11.0-cudnn8-devel
# ENV http_proxy=http://aaa.bbb.cc.net:8080 \
# https_proxy=http://aaa.bbb.cc.net:8080
RUN rm /etc/apt/sources.list.d/cuda.list && \
rm /etc/apt/sources.list.d/nvidia-ml.list && \
apt-key del 7fa2af80
# install normal source
RUN apt-get update && \
apt-get install -y --no-install-recommends \
@ -36,7 +36,7 @@ RUN curl -fsSL https://developer.download.nvidia.com/compute/cuda/repos/ubuntu18
curl -fsSL https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub | apt-key add - && \
echo "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/cuda.list && \
echo "deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/nvidia-ml.list && \
rm -rf /var/lib/apt/lists/* && \
mv /opt/conda/lib/libcufft.so.10 /opt/libcufft.so.10.bak && \
mv /opt/conda/lib/libcurand.so.10 /opt/libcurand.so.10.bak && \
mv /opt/conda/lib/libcublas.so.11 /opt/libcublas.so.11.bak && \
@ -56,10 +56,10 @@ RUN wget -P /opt https://cmake.org/files/v3.18/cmake-3.18.0.tar.gz && \
rm -rf cmake-3.18.0.tar.gz && \
find /opt/cmake-3.18.0 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \
cd -
# flac
RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz && \
cd /opt && \
xz -d flac-1.3.2.tar.xz && \
tar -xvf flac-1.3.2.tar && \
cd flac-1.3.2 && \
@ -67,7 +67,7 @@ RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz &&
make && make install && \
rm -rf flac-1.3.2.tar && \
find /opt/flac-1.3.2 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \
cd -
RUN conda install -y -c pytorch torchaudio=0.7.1 && \
pip install graphviz
@ -79,7 +79,7 @@ RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \
cd -
# install lhotse
RUN pip install git+https://github.com/lhotse-speech/lhotse
RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
cd /workspace/icefall && \
@ -88,4 +88,3 @@ RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
WORKDIR /workspace/icefall

View File

@ -1 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" width="80" height="20" role="img" aria-label="k2: &gt;= v1.9"><title>k2: &gt;= v1.9</title><linearGradient id="s" x2="0" y2="100%"><stop offset="0" stop-color="#bbb" stop-opacity=".1"/><stop offset="1" stop-opacity=".1"/></linearGradient><clipPath id="r"><rect width="80" height="20" rx="3" fill="#fff"/></clipPath><g clip-path="url(#r)"><rect width="23" height="20" fill="#555"/><rect x="23" width="57" height="20" fill="blueviolet"/><rect width="80" height="20" fill="url(#s)"/></g><g fill="#fff" text-anchor="middle" font-family="Verdana,Geneva,DejaVu Sans,sans-serif" text-rendering="geometricPrecision" font-size="110"><text aria-hidden="true" x="125" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="130">k2</text><text x="125" y="140" transform="scale(.1)" fill="#fff" textLength="130">k2</text><text aria-hidden="true" x="505" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="470">&gt;= v1.9</text><text x="505" y="140" transform="scale(.1)" fill="#fff" textLength="470">&gt;= v1.9</text></g></svg>

View File

@ -1 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" width="98" height="20" role="img" aria-label="python: &gt;= 3.6"><title>python: &gt;= 3.6</title><linearGradient id="s" x2="0" y2="100%"><stop offset="0" stop-color="#bbb" stop-opacity=".1"/><stop offset="1" stop-opacity=".1"/></linearGradient><clipPath id="r"><rect width="98" height="20" rx="3" fill="#fff"/></clipPath><g clip-path="url(#r)"><rect width="49" height="20" fill="#555"/><rect x="49" width="49" height="20" fill="#007ec6"/><rect width="98" height="20" fill="url(#s)"/></g><g fill="#fff" text-anchor="middle" font-family="Verdana,Geneva,DejaVu Sans,sans-serif" text-rendering="geometricPrecision" font-size="110"><text aria-hidden="true" x="255" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="390">python</text><text x="255" y="140" transform="scale(.1)" fill="#fff" textLength="390">python</text><text aria-hidden="true" x="725" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="390">&gt;= 3.6</text><text x="725" y="140" transform="scale(.1)" fill="#fff" textLength="390">&gt;= 3.6</text></g></svg>

View File

@ -1 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" width="100" height="20" role="img" aria-label="torch: &gt;= 1.6.0"><title>torch: &gt;= 1.6.0</title><linearGradient id="s" x2="0" y2="100%"><stop offset="0" stop-color="#bbb" stop-opacity=".1"/><stop offset="1" stop-opacity=".1"/></linearGradient><clipPath id="r"><rect width="100" height="20" rx="3" fill="#fff"/></clipPath><g clip-path="url(#r)"><rect width="39" height="20" fill="#555"/><rect x="39" width="61" height="20" fill="#97ca00"/><rect width="100" height="20" fill="url(#s)"/></g><g fill="#fff" text-anchor="middle" font-family="Verdana,Geneva,DejaVu Sans,sans-serif" text-rendering="geometricPrecision" font-size="110"><text aria-hidden="true" x="205" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="290">torch</text><text x="205" y="140" transform="scale(.1)" fill="#fff" textLength="290">torch</text><text aria-hidden="true" x="685" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="510">&gt;= 1.6.0</text><text x="685" y="140" transform="scale(.1)" fill="#fff" textLength="510">&gt;= 1.6.0</text></g></svg>

View File

@ -19,4 +19,3 @@ It can be downloaded from `<https://www.openslr.org/33/>`_
tdnn_lstm_ctc
conformer_ctc
stateless_transducer

View File

@ -6,4 +6,3 @@ TIMIT
tdnn_ligru_ctc
tdnn_lstm_ctc

View File

@ -148,10 +148,10 @@ Some commonly used options are:
$ ./tdnn_ligru_ctc/decode.py --epoch 25 --avg 17
uses the average of ``epoch-9.pt``, ``epoch-10.pt``, ``epoch-11.pt``,
``epoch-12.pt``, ``epoch-13.pt``, ``epoch-14.pt``, ``epoch-15.pt``,
``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``, ``epoch-19.pt``,
``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``, ``epoch-23.pt``,
``epoch-24.pt`` and ``epoch-25.pt``
for decoding.
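To make the relationship concrete, a hypothetical smaller example (not from the original docs) follows the same rule: ``--avg N`` averages the last N checkpoints ending at ``--epoch``.

```bash
# Hypothetical example: average only epoch-21.pt through epoch-25.pt for decoding.
./tdnn_ligru_ctc/decode.py --epoch 25 --avg 5
```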
@ -317,13 +317,13 @@ To decode with ``1best`` method, we can use:
.. code-block:: bash
./tdnn_ligru_ctc/pretrained.py
--method 1best
--checkpoint ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/exp/pretrained_average_9_25.pt
--words-file ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/words.txt
--HLG ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/HLG.pt
./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV
./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV
./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV
The output is:
@ -337,7 +337,7 @@ The output is:
2021-11-08 20:41:38,697 INFO [pretrained.py:210] Reading sound files: ['./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV']
2021-11-08 20:41:38,704 INFO [pretrained.py:216] Decoding started
2021-11-08 20:41:39,819 INFO [pretrained.py:246] Use HLG decoding
2021-11-08 20:41:39,829 INFO [pretrained.py:267]
./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV:
sil dh ih sh uw ah l iy v iy z ih sil p r aa sil k s ih m ey dx ih sil d w uh dx ih w ih s f iy l ih ng w ih th ih n ih m s eh l f sil jh
@ -362,8 +362,8 @@ To decode with ``whole-lattice-rescoring`` method, you can use
--HLG ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/HLG.pt \
--G ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lm/G_4_gram.pt \
--ngram-lm-scale 0.1 \
./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV
./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV
./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV
The decoding output is:
@ -378,7 +378,7 @@ The decoding output is:
2021-11-08 20:37:54,715 INFO [pretrained.py:210] Reading sound files: ['./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV']
2021-11-08 20:37:54,720 INFO [pretrained.py:216] Decoding started
2021-11-08 20:37:55,808 INFO [pretrained.py:251] Use HLG decoding + LM rescoring
2021-11-08 20:37:56,348 INFO [pretrained.py:267]
./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV:
sil dh ih sh uw ah l iy v iy z ah sil p r aa sil k s ih m ey dx ih sil d w uh dx iy w ih s f iy l iy ng w ih th ih n ih m s eh l f sil jh

View File

@ -148,8 +148,8 @@ Some commonly used options are:
$ ./tdnn_lstm_ctc/decode.py --epoch 25 --avg 10
uses the average of ``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``,
``epoch-19.pt``, ``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``,
``epoch-23.pt``, ``epoch-24.pt`` and ``epoch-25.pt``
for decoding.
@ -315,13 +315,13 @@ To decode with ``1best`` method, we can use:
.. code-block:: bash
./tdnn_lstm_ctc/pretrained.py
--method 1best
--checkpoint ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/exp/pretrained_average_16_25.pt
--words-file ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/words.txt
--HLG ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/HLG.pt
./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV
./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV
./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV
The output is:
@ -335,7 +335,7 @@ The output is:
2021-11-08 21:02:53,827 INFO [pretrained.py:210] Reading sound files: ['./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV']
2021-11-08 21:02:53,831 INFO [pretrained.py:216] Decoding started
2021-11-08 21:02:54,380 INFO [pretrained.py:246] Use HLG decoding
2021-11-08 21:02:54,387 INFO [pretrained.py:267]
./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV:
sil dh ih sh uw ah l iy v iy z ih sil p r aa sil k s ih m ey dx ih sil d w uh dx iy w ih s f iy l iy w ih th ih n ih m s eh l f sil jh
@ -360,8 +360,8 @@ To decode with ``whole-lattice-rescoring`` method, you can use
--HLG ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/HLG.pt \
--G ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lm/G_4_gram.pt \
--ngram-lm-scale 0.08 \
./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV
./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV
./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV
The decoding output is:
@ -376,7 +376,7 @@ The decoding output is:
2021-11-08 20:05:26,978 INFO [pretrained.py:210] Reading sound files: ['./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV']
2021-11-08 20:05:26,981 INFO [pretrained.py:216] Decoding started
2021-11-08 20:05:27,519 INFO [pretrained.py:251] Use HLG decoding + LM rescoring
2021-11-08 20:05:27,878 INFO [pretrained.py:267]
./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV:
sil dh ih sh uw l iy v iy z ih sil p r aa sil k s ah m ey dx ih sil w uh dx iy w ih s f iy l ih ng w ih th ih n ih m s eh l f sil jh

View File

@ -87,9 +87,7 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
)
if "train" in partition:
cut_set = (
cut_set
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
@ -116,9 +114,7 @@ def get_args():
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)

View File

@ -86,9 +86,7 @@ def lexicon_to_fst_no_sil(
cur_state = loop_state
word = word2id[word]
pieces = [
token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
]
pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
for i in range(len(pieces) - 1):
w = word if i == 0 else eps
@ -142,9 +140,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
return False
def generate_lexicon(
token_sym_table: Dict[str, int], words: List[str]
) -> Lexicon:
def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
"""Generate a lexicon from a word list and token_sym_table.
Args:

View File

@ -317,9 +317,7 @@ def lexicon_to_fst(
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lang-dir", type=str, help="The lang dir, data/lang_phone"
)
parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
return parser.parse_args()

View File

@ -88,9 +88,7 @@ def test_read_lexicon(filename: str):
fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa.draw("L.pdf", title="L")
fsa_disambig = lexicon_to_fst(
lexicon_disambig, phone2id=phone2id, word2id=word2id
)
fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa_disambig.draw("L_disambig.pdf", title="L_disambig")

View File

@ -56,9 +56,7 @@ def get_parser():
parser.add_argument(
"--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
)
parser.add_argument(
"--space", default="<space>", type=str, help="space symbol"
)
parser.add_argument("--space", default="<space>", type=str, help="space symbol")
parser.add_argument(
"--non-lang-syms",
"-l",
@ -66,9 +64,7 @@ def get_parser():
type=str,
help="list of non-linguistic symobles, e.g., <NOISE> etc.",
)
parser.add_argument(
"text", type=str, default=False, nargs="?", help="input text"
)
parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
parser.add_argument(
"--trans_type",
"-t",
@ -108,8 +104,7 @@ def token2id(
if token_type == "lazy_pinyin":
text = lazy_pinyin(chars_list)
sub_ids = [
token_table[txt] if txt in token_table else oov_id
for txt in text
token_table[txt] if txt in token_table else oov_id for txt in text
]
ids.append(sub_ids)
else: # token_type = "pinyin"
@ -135,9 +130,7 @@ def main():
if args.text:
f = codecs.open(args.text, encoding="utf-8")
else:
f = codecs.getreader("utf-8")(
sys.stdin if is_python2 else sys.stdin.buffer
)
f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
sys.stdout = codecs.getwriter("utf-8")(
sys.stdout if is_python2 else sys.stdout.buffer

View File

@ -106,11 +106,10 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
if [ ! -f $lang_char_dir/words.txt ]; then
./local/prepare_words.py \
--input-file $lang_char_dir/words_no_ids.txt \
--output-file $lang_char_dir/words.txt
fi
if [ ! -f $lang_char_dir/L_disambig.pt ]; then
./local/prepare_char.py
fi
fi
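To rerun only this lexicon stage, a rough sketch, assuming this prepare.sh follows the usual icefall convention of accepting `--stage`/`--stop-stage` through parse_options.sh (an assumption; check the script header):

```bash
# Regenerate words.txt and L_disambig.pt only (stage 5 of data preparation).
./prepare.sh --stage 5 --stop-stage 5
```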

View File

@ -205,17 +205,13 @@ class Aidatatang_200zhAsrDataModule:
The state dict for the training sampler.
"""
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(
self.args.manifest_dir / "musan_cuts.jsonl.gz"
)
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
transforms.append(
CutMix(
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
)
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
@ -237,9 +233,7 @@ class Aidatatang_200zhAsrDataModule:
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
)
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
@ -282,9 +276,7 @@ class Aidatatang_200zhAsrDataModule:
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
@ -340,9 +332,7 @@ class Aidatatang_200zhAsrDataModule:
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
return_cuts=self.args.return_cuts,
)
else:

View File

@ -69,11 +69,7 @@ from beam_search import (
)
from train import get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
@ -192,8 +188,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
@ -249,9 +244,7 @@ def decode_one_batch(
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
hyps = []
if params.decoding_method == "fast_beam_search":
@ -266,10 +259,7 @@ def decode_one_batch(
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
@ -390,9 +380,7 @@ def decode_dataset(
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
@ -425,8 +413,7 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)

View File

@ -103,8 +103,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
return parser
@ -173,9 +172,7 @@ def main():
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -162,8 +162,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
parser.add_argument(
@ -194,8 +193,7 @@ def read_sound_files(
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
@ -257,9 +255,7 @@ def main():
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
feature_lengths = torch.tensor(feature_lengths, device=device)
@ -284,10 +280,7 @@ def main():
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
@ -339,9 +332,7 @@ def main():
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -81,9 +81,7 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
]
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
@ -187,8 +185,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
parser.add_argument(
@ -211,8 +208,7 @@ def get_parser():
"--am-scale",
type=float,
default=0.0,
help="The scale to smooth the loss with am (output of encoder network)"
"part.",
help="The scale to smooth the loss with am (output of encoder network)" "part.",
)
parser.add_argument(
@ -542,22 +538,15 @@ def compute_loss(
# overwhelming the simple_loss and causing it to diverge,
# in case it had not fully learned the alignment yet.
pruned_loss_scale = (
0.0
if warmup < 1.0
else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
)
loss = (
params.simple_loss_scale * simple_loss
+ pruned_loss_scale * pruned_loss
0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
)
loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
assert loss.requires_grad == is_training
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
@ -711,9 +700,7 @@ def train_one_epoch(
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
@ -813,7 +800,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2 ** 22
2**22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -157,9 +157,7 @@ class ConformerEncoderLayer(nn.Module):
normalize_before: bool = True,
) -> None:
super(ConformerEncoderLayer, self).__init__()
self.self_attn = RelPositionMultiheadAttention(
d_model, nhead, dropout=0.0
)
self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
self.feed_forward = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
@ -177,18 +175,14 @@ class ConformerEncoderLayer(nn.Module):
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
self.norm_ff_macaron = nn.LayerNorm(
d_model
) # for the macaron style FNN module
self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module
self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
self.norm_mha = nn.LayerNorm(d_model) # for the MHA module
self.ff_scale = 0.5
self.norm_conv = nn.LayerNorm(d_model) # for the CNN module
self.norm_final = nn.LayerNorm(
d_model
) # for the final output of the block
self.norm_final = nn.LayerNorm(d_model) # for the final output of the block
self.dropout = nn.Dropout(dropout)
@ -222,9 +216,7 @@ class ConformerEncoderLayer(nn.Module):
residual = src
if self.normalize_before:
src = self.norm_ff_macaron(src)
src = residual + self.ff_scale * self.dropout(
self.feed_forward_macaron(src)
)
src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
if not self.normalize_before:
src = self.norm_ff_macaron(src)
@ -343,9 +335,7 @@ class RelPositionalEncoding(torch.nn.Module):
"""
def __init__(
self, d_model: int, dropout_rate: float, max_len: int = 5000
) -> None:
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
"""Construct an PositionalEncoding object."""
super(RelPositionalEncoding, self).__init__()
self.d_model = d_model
@ -361,9 +351,7 @@ class RelPositionalEncoding(torch.nn.Module):
# the length of self.pe is 2 * input_len - 1
if self.pe.size(1) >= x.size(1) * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
x.device
):
if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
# Suppose `i` means to the position of query vector and `j` means the
@ -633,9 +621,9 @@ class RelPositionMultiheadAttention(nn.Module):
if torch.equal(query, key) and torch.equal(key, value):
# self-attention
q, k, v = nn.functional.linear(
query, in_proj_weight, in_proj_bias
).chunk(3, dim=-1)
q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
3, dim=-1
)
elif torch.equal(key, value):
# encoder-decoder attention
@ -703,31 +691,22 @@ class RelPositionMultiheadAttention(nn.Module):
if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(0)
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
raise RuntimeError(
"The size of the 2D attn_mask is not correct."
)
raise RuntimeError("The size of the 2D attn_mask is not correct.")
elif attn_mask.dim() == 3:
if list(attn_mask.size()) != [
bsz * num_heads,
query.size(0),
key.size(0),
]:
raise RuntimeError(
"The size of the 3D attn_mask is not correct."
)
raise RuntimeError("The size of the 3D attn_mask is not correct.")
else:
raise RuntimeError(
"attn_mask's dimension {} is not supported".format(
attn_mask.dim()
)
"attn_mask's dimension {} is not supported".format(attn_mask.dim())
)
# attn_mask's dim is 3 now.
# convert ByteTensor key_padding_mask to bool
if (
key_padding_mask is not None
and key_padding_mask.dtype == torch.uint8
):
if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
warnings.warn(
"Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
)
@ -766,9 +745,7 @@ class RelPositionMultiheadAttention(nn.Module):
# first compute matrix a and matrix c
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
matrix_ac = torch.matmul(
q_with_bias_u, k
) # (batch, head, time1, time2)
matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2)
# compute matrix b and matrix d
matrix_bd = torch.matmul(
@ -780,9 +757,7 @@ class RelPositionMultiheadAttention(nn.Module):
matrix_ac + matrix_bd
) * scaling # (batch, head, time1, time2)
attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, -1
)
attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
assert list(attn_output_weights.size()) == [
bsz * num_heads,
@ -816,13 +791,9 @@ class RelPositionMultiheadAttention(nn.Module):
attn_output = torch.bmm(attn_output_weights, v)
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
attn_output = (
attn_output.transpose(0, 1)
.contiguous()
.view(tgt_len, bsz, embed_dim)
)
attn_output = nn.functional.linear(
attn_output, out_proj_weight, out_proj_bias
attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
)
attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
if need_weights:
# average attention weights over heads
@ -845,9 +816,7 @@ class ConvolutionModule(nn.Module):
"""
def __init__(
self, channels: int, kernel_size: int, bias: bool = True
) -> None:
def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None:
"""Construct an ConvolutionModule object."""
super(ConvolutionModule, self).__init__()
# kernel_size should be an odd number for 'SAME' padding

View File

@ -401,9 +401,7 @@ def decode_dataset(
if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
@ -431,9 +429,7 @@ def save_results(
# we compute CER for aishell dataset.
results_char = []
for res in results:
results_char.append(
(res[0], list("".join(res[1])), list("".join(res[2])))
)
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
@ -441,9 +437,7 @@ def save_results(
test_set_wers[key] = wer
if enable_log:
logging.info(
"Wrote detailed error stats to {}".format(errs_filename)
)
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt"
@ -562,9 +556,7 @@ def main():
eos_id=eos_id,
)
save_results(
params=params, test_set_name=test_set, results_dict=results_dict
)
save_results(params=params, test_set_name=test_set, results_dict=results_dict)
logging.info("Done!")

View File

@ -157,9 +157,7 @@ def main():
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -211,8 +211,7 @@ def read_sound_files(
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
@ -274,9 +273,7 @@ def main():
logging.info("Decoding started")
features = fbank(waves)
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
# Note: We don't use key padding mask for attention during decoding
with torch.no_grad():
@ -371,9 +368,7 @@ def main():
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -42,13 +42,9 @@ class Conv2dSubsampling(nn.Module):
assert idim >= 7
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(
in_channels=1, out_channels=odim, kernel_size=3, stride=2
),
nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2),
nn.ReLU(),
nn.Conv2d(
in_channels=odim, out_channels=odim, kernel_size=3, stride=2
),
nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2),
nn.ReLU(),
)
self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
@ -132,17 +128,13 @@ class VggSubsampling(nn.Module):
)
)
layers.append(
torch.nn.MaxPool2d(
kernel_size=2, stride=2, padding=0, ceil_mode=True
)
torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True)
)
cur_channels = block_dim
self.layers = nn.Sequential(*layers)
self.out = nn.Linear(
block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
)
self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x.

View File

@ -16,9 +16,8 @@
# limitations under the License.
from subsampling import Conv2dSubsampling
from subsampling import VggSubsampling
import torch
from subsampling import Conv2dSubsampling, VggSubsampling
def test_conv2d_subsampling():

View File

@ -382,9 +382,7 @@ def compute_loss(
#
# See https://github.com/k2-fsa/icefall/issues/97
# for more details
unsorted_token_ids = graph_compiler.texts_to_ids(
supervisions["text"]
)
unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"])
att_loss = mmodel.decoder_forward(
encoder_memory,
memory_mask,
@ -520,9 +518,7 @@ def train_one_epoch(
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
@ -630,9 +626,7 @@ def run(rank, world_size, args):
cur_lr = optimizer._rate
if tb_writer is not None:
tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
)
tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
if rank == 0:

View File

@ -149,9 +149,7 @@ class Transformer(nn.Module):
norm=decoder_norm,
)
self.decoder_output_layer = torch.nn.Linear(
d_model, self.decoder_num_class
)
self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class)
self.decoder_criterion = LabelSmoothingLoss()
else:
@ -183,9 +181,7 @@ class Transformer(nn.Module):
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
x = self.feat_batchnorm(x)
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
encoder_memory, memory_key_padding_mask = self.run_encoder(
x, supervision
)
encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision)
x = self.ctc_output(encoder_memory)
return x, encoder_memory, memory_key_padding_mask
@ -266,23 +262,17 @@ class Transformer(nn.Module):
"""
ys_in = add_sos(token_ids, sos_id=sos_id)
ys_in = [torch.tensor(y) for y in ys_in]
ys_in_pad = pad_sequence(
ys_in, batch_first=True, padding_value=float(eos_id)
)
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
ys_out = add_eos(token_ids, eos_id=eos_id)
ys_out = [torch.tensor(y) for y in ys_out]
ys_out_pad = pad_sequence(
ys_out, batch_first=True, padding_value=float(-1)
)
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
device = memory.device
ys_in_pad = ys_in_pad.to(device)
ys_out_pad = ys_out_pad.to(device)
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
device
)
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
# TODO: Use length information to create the decoder padding mask
@ -343,23 +333,17 @@ class Transformer(nn.Module):
ys_in = add_sos(token_ids, sos_id=sos_id)
ys_in = [torch.tensor(y) for y in ys_in]
ys_in_pad = pad_sequence(
ys_in, batch_first=True, padding_value=float(eos_id)
)
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
ys_out = add_eos(token_ids, eos_id=eos_id)
ys_out = [torch.tensor(y) for y in ys_out]
ys_out_pad = pad_sequence(
ys_out, batch_first=True, padding_value=float(-1)
)
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
device = memory.device
ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
device
)
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
# TODO: Use length information to create the decoder padding mask
@ -632,9 +616,7 @@ def _get_activation_fn(activation: str):
elif activation == "gelu":
return nn.functional.gelu
raise RuntimeError(
"activation should be relu/gelu, not {}".format(activation)
)
raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
class PositionalEncoding(nn.Module):
@ -836,9 +818,7 @@ def encoder_padding_mask(
1,
).to(torch.int32)
lengths = [
0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)
]
lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)]
for idx in range(supervision_segments.size(0)):
# Note: TorchScript doesn't allow to unpack tensors as tuples
sequence_idx = supervision_segments[idx, 0].item()
@ -859,9 +839,7 @@ def encoder_padding_mask(
return mask
def decoder_padding_mask(
ys_pad: torch.Tensor, ignore_id: int = -1
) -> torch.Tensor:
def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor:
"""Generate a length mask for input.
The masked position are filled with True,

View File

@ -157,9 +157,7 @@ class ConformerEncoderLayer(nn.Module):
normalize_before: bool = True,
) -> None:
super(ConformerEncoderLayer, self).__init__()
self.self_attn = RelPositionMultiheadAttention(
d_model, nhead, dropout=0.0
)
self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
self.feed_forward = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
@ -177,18 +175,14 @@ class ConformerEncoderLayer(nn.Module):
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
self.norm_ff_macaron = nn.LayerNorm(
d_model
) # for the macaron style FNN module
self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module
self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
self.norm_mha = nn.LayerNorm(d_model) # for the MHA module
self.ff_scale = 0.5
self.norm_conv = nn.LayerNorm(d_model) # for the CNN module
self.norm_final = nn.LayerNorm(
d_model
) # for the final output of the block
self.norm_final = nn.LayerNorm(d_model) # for the final output of the block
self.dropout = nn.Dropout(dropout)
@ -222,9 +216,7 @@ class ConformerEncoderLayer(nn.Module):
residual = src
if self.normalize_before:
src = self.norm_ff_macaron(src)
src = residual + self.ff_scale * self.dropout(
self.feed_forward_macaron(src)
)
src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
if not self.normalize_before:
src = self.norm_ff_macaron(src)
@ -343,9 +335,7 @@ class RelPositionalEncoding(torch.nn.Module):
"""
def __init__(
self, d_model: int, dropout_rate: float, max_len: int = 5000
) -> None:
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
"""Construct an PositionalEncoding object."""
super(RelPositionalEncoding, self).__init__()
self.d_model = d_model
@ -361,9 +351,7 @@ class RelPositionalEncoding(torch.nn.Module):
# the length of self.pe is 2 * input_len - 1
if self.pe.size(1) >= x.size(1) * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
x.device
):
if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
# Suppose `i` means to the position of query vector and `j` means the
@ -633,9 +621,9 @@ class RelPositionMultiheadAttention(nn.Module):
if torch.equal(query, key) and torch.equal(key, value):
# self-attention
q, k, v = nn.functional.linear(
query, in_proj_weight, in_proj_bias
).chunk(3, dim=-1)
q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
3, dim=-1
)
elif torch.equal(key, value):
# encoder-decoder attention
@ -703,31 +691,22 @@ class RelPositionMultiheadAttention(nn.Module):
if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(0)
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
raise RuntimeError(
"The size of the 2D attn_mask is not correct."
)
raise RuntimeError("The size of the 2D attn_mask is not correct.")
elif attn_mask.dim() == 3:
if list(attn_mask.size()) != [
bsz * num_heads,
query.size(0),
key.size(0),
]:
raise RuntimeError(
"The size of the 3D attn_mask is not correct."
)
raise RuntimeError("The size of the 3D attn_mask is not correct.")
else:
raise RuntimeError(
"attn_mask's dimension {} is not supported".format(
attn_mask.dim()
)
"attn_mask's dimension {} is not supported".format(attn_mask.dim())
)
# attn_mask's dim is 3 now.
# convert ByteTensor key_padding_mask to bool
if (
key_padding_mask is not None
and key_padding_mask.dtype == torch.uint8
):
if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
warnings.warn(
"Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
)
@ -766,9 +745,7 @@ class RelPositionMultiheadAttention(nn.Module):
# first compute matrix a and matrix c
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
matrix_ac = torch.matmul(
q_with_bias_u, k
) # (batch, head, time1, time2)
matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2)
# compute matrix b and matrix d
matrix_bd = torch.matmul(
@ -780,9 +757,7 @@ class RelPositionMultiheadAttention(nn.Module):
matrix_ac + matrix_bd
) * scaling # (batch, head, time1, time2)
attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, -1
)
attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
assert list(attn_output_weights.size()) == [
bsz * num_heads,
@ -816,13 +791,9 @@ class RelPositionMultiheadAttention(nn.Module):
attn_output = torch.bmm(attn_output_weights, v)
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
attn_output = (
attn_output.transpose(0, 1)
.contiguous()
.view(tgt_len, bsz, embed_dim)
)
attn_output = nn.functional.linear(
attn_output, out_proj_weight, out_proj_bias
attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
)
attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
if need_weights:
# average attention weights over heads
@ -845,9 +816,7 @@ class ConvolutionModule(nn.Module):
"""
def __init__(
self, channels: int, kernel_size: int, bias: bool = True
) -> None:
def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None:
"""Construct an ConvolutionModule object."""
super(ConvolutionModule, self).__init__()
# kernel_size should be an odd number for 'SAME' padding

View File

@ -413,9 +413,7 @@ def decode_dataset(
if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
@ -443,9 +441,7 @@ def save_results(
# we compute CER for aishell dataset.
results_char = []
for res in results:
results_char.append(
(res[0], list("".join(res[1])), list("".join(res[2])))
)
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
@ -453,9 +449,7 @@ def save_results(
test_set_wers[key] = wer
if enable_log:
logging.info(
"Wrote detailed error stats to {}".format(errs_filename)
)
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt"
@ -550,9 +544,7 @@ def main():
if params.export:
logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
torch.save(
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
)
torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
return
model.to(device)
@ -581,9 +573,7 @@ def main():
eos_id=eos_id,
)
save_results(
params=params, test_set_name=test_set, results_dict=results_dict
)
save_results(params=params, test_set_name=test_set, results_dict=results_dict)
logging.info("Done!")

View File

@ -42,13 +42,9 @@ class Conv2dSubsampling(nn.Module):
assert idim >= 7
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(
in_channels=1, out_channels=odim, kernel_size=3, stride=2
),
nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2),
nn.ReLU(),
nn.Conv2d(
in_channels=odim, out_channels=odim, kernel_size=3, stride=2
),
nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2),
nn.ReLU(),
)
self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
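The `Conv2dSubsampling` hunk above keeps the projection size `odim * (((idim - 1) // 2 - 1) // 2)`. A quick worked check of that arithmetic, with example values that are not taken from the diff:

```python
# A Conv2d with kernel_size=3, stride=2 and no padding maps a length L to
# (L - 1) // 2, so two of them in a row map idim to ((idim - 1) // 2 - 1) // 2.
idim, odim = 80, 256  # example values only
freq_after = ((idim - 1) // 2 - 1) // 2
print(freq_after)          # 19
print(odim * freq_after)   # 4864 -> input features of the final nn.Linear
```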
@ -132,17 +128,13 @@ class VggSubsampling(nn.Module):
)
)
layers.append(
torch.nn.MaxPool2d(
kernel_size=2, stride=2, padding=0, ceil_mode=True
)
torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True)
)
cur_channels = block_dim
self.layers = nn.Sequential(*layers)
self.out = nn.Linear(
block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
)
self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x.

View File

@ -511,9 +511,7 @@ def train_one_epoch(
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
@ -625,9 +623,7 @@ def run(rank, world_size, args):
cur_lr = optimizer._rate
if tb_writer is not None:
tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
)
tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
if rank == 0:

View File

@ -149,9 +149,7 @@ class Transformer(nn.Module):
norm=decoder_norm,
)
self.decoder_output_layer = torch.nn.Linear(
d_model, self.decoder_num_class
)
self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class)
self.decoder_criterion = LabelSmoothingLoss()
else:
@ -183,9 +181,7 @@ class Transformer(nn.Module):
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
x = self.feat_batchnorm(x)
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
encoder_memory, memory_key_padding_mask = self.run_encoder(
x, supervision
)
encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision)
x = self.ctc_output(encoder_memory)
return x, encoder_memory, memory_key_padding_mask
@ -266,23 +262,17 @@ class Transformer(nn.Module):
"""
ys_in = add_sos(token_ids, sos_id=sos_id)
ys_in = [torch.tensor(y) for y in ys_in]
ys_in_pad = pad_sequence(
ys_in, batch_first=True, padding_value=float(eos_id)
)
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
ys_out = add_eos(token_ids, eos_id=eos_id)
ys_out = [torch.tensor(y) for y in ys_out]
ys_out_pad = pad_sequence(
ys_out, batch_first=True, padding_value=float(-1)
)
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
device = memory.device
ys_in_pad = ys_in_pad.to(device)
ys_out_pad = ys_out_pad.to(device)
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
device
)
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
# TODO: Use length information to create the decoder padding mask
@ -343,23 +333,17 @@ class Transformer(nn.Module):
ys_in = add_sos(token_ids, sos_id=sos_id)
ys_in = [torch.tensor(y) for y in ys_in]
ys_in_pad = pad_sequence(
ys_in, batch_first=True, padding_value=float(eos_id)
)
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
ys_out = add_eos(token_ids, eos_id=eos_id)
ys_out = [torch.tensor(y) for y in ys_out]
ys_out_pad = pad_sequence(
ys_out, batch_first=True, padding_value=float(-1)
)
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
device = memory.device
ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
device
)
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
# TODO: Use length information to create the decoder padding mask
@ -632,9 +616,7 @@ def _get_activation_fn(activation: str):
elif activation == "gelu":
return nn.functional.gelu
raise RuntimeError(
"activation should be relu/gelu, not {}".format(activation)
)
raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
class PositionalEncoding(nn.Module):
@ -836,9 +818,7 @@ def encoder_padding_mask(
1,
).to(torch.int32)
lengths = [
0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)
]
lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)]
for idx in range(supervision_segments.size(0)):
# Note: TorchScript doesn't allow unpacking tensors as tuples
sequence_idx = supervision_segments[idx, 0].item()
@ -859,9 +839,7 @@ def encoder_padding_mask(
return mask
def decoder_padding_mask(
ys_pad: torch.Tensor, ignore_id: int = -1
) -> torch.Tensor:
def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor:
"""Generate a length mask for input.
The masked positions are filled with True,

View File

@ -87,9 +87,7 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
)
if "train" in partition:
cut_set = (
cut_set
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
@ -116,9 +114,7 @@ def get_args():
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)

View File

@ -83,9 +83,7 @@ def compute_fbank_aishell(num_mel_bins: int = 80):
)
if "train" in partition:
cut_set = (
cut_set
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
@ -111,9 +109,7 @@ def get_args():
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)

View File

@ -86,9 +86,7 @@ def lexicon_to_fst_no_sil(
cur_state = loop_state
word = word2id[word]
pieces = [
token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
]
pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
for i in range(len(pieces) - 1):
w = word if i == 0 else eps
@ -142,9 +140,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
return False
def generate_lexicon(
token_sym_table: Dict[str, int], words: List[str]
) -> Lexicon:
def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
"""Generate a lexicon from a word list and token_sym_table.
Args:

View File

@ -317,9 +317,7 @@ def lexicon_to_fst(
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lang-dir", type=str, help="The lang dir, data/lang_phone"
)
parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
return parser.parse_args()

View File

@ -88,9 +88,7 @@ def test_read_lexicon(filename: str):
fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa.draw("L.pdf", title="L")
fsa_disambig = lexicon_to_fst(
lexicon_disambig, phone2id=phone2id, word2id=word2id
)
fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa_disambig.draw("L_disambig.pdf", title="L_disambig")

View File

@ -76,11 +76,7 @@ from beam_search import (
)
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
@ -188,8 +184,7 @@ def get_parser():
"--context-size",
type=int,
default=1,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
@ -249,9 +244,7 @@ def decode_one_batch(
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
@ -263,10 +256,7 @@ def decode_one_batch(
max_contexts=params.max_contexts,
max_states=params.max_states,
)
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
@ -387,9 +377,7 @@ def decode_dataset(
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
@ -415,9 +403,7 @@ def save_results(
# we compute CER for aishell dataset.
results_char = []
for res in results:
results_char.append(
(res[0], list("".join(res[1])), list("".join(res[2])))
)
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True
@ -428,8 +414,7 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
@ -473,9 +458,7 @@ def main():
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -504,8 +487,7 @@ def main():
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(

View File

@ -50,11 +50,7 @@ from pathlib import Path
import torch
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
@ -120,8 +116,7 @@ def get_parser():
"--context-size",
type=int,
default=1,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
add_model_arguments(parser)
@ -157,8 +152,7 @@ def main():
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
@ -191,9 +185,7 @@ def main():
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = (
params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
)
filename = params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
@ -201,17 +193,14 @@ def main():
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = (
params.exp_dir
/ f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
params.exp_dir / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
)
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -165,8 +165,7 @@ def get_parser():
"--context-size",
type=int,
default=1,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
@ -197,8 +196,7 @@ def read_sound_files(
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
@ -256,13 +254,9 @@ def main():
feature_lens = [f.size(0) for f in features]
feature_lens = torch.tensor(feature_lens, device=device)
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lens
)
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)
num_waves = encoder_out.size(0)
hyp_list = []
@ -310,9 +304,7 @@ def main():
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.method}"
)
raise ValueError(f"Unsupported decoding method: {params.method}")
hyp_list.append(hyp)
hyps = []
@ -329,9 +321,7 @@ def main():
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -49,7 +49,6 @@ import optim
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import AishellAsrDataModule
from conformer import Conformer
from decoder import Decoder
@ -75,9 +74,7 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
]
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
def add_model_arguments(parser: argparse.ArgumentParser):
@ -203,8 +200,7 @@ def get_parser():
"--initial-lr",
type=float,
default=0.003,
help="The initial learning rate. This value should not need "
"to be changed.",
help="The initial learning rate. This value should not need " "to be changed.",
)
parser.add_argument(
@ -227,8 +223,7 @@ def get_parser():
"--context-size",
type=int,
default=1,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
parser.add_argument(
@ -251,8 +246,7 @@ def get_parser():
"--am-scale",
type=float,
default=0.0,
help="The scale to smooth the loss with am (output of encoder network)"
"part.",
help="The scale to smooth the loss with am (output of encoder network)" "part.",
)
parser.add_argument(
@ -561,11 +555,7 @@ def compute_loss(
warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present.
"""
device = (
model.device
if isinstance(model, DDP)
else next(model.parameters()).device
)
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
@ -593,23 +583,16 @@ def compute_loss(
# overwhelming the simple_loss and causing it to diverge,
# in case it had not fully learned the alignment yet.
pruned_loss_scale = (
0.0
if warmup < 1.0
else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
)
loss = (
params.simple_loss_scale * simple_loss
+ pruned_loss_scale * pruned_loss
0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
)
loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
assert loss.requires_grad == is_training
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
@ -725,9 +708,7 @@ def train_one_epoch(
scaler.update()
optimizer.zero_grad()
except: # noqa
display_and_save_batch(
batch, params=params, graph_compiler=graph_compiler
)
display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
raise
if params.print_diagnostics and batch_idx == 5:
@ -891,7 +872,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2 ** 22
2**22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
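The only change in this hunk is `2 ** 22` → `2**22` (black drops the spaces around `**`); the value itself is the "4 megabytes per sub-module" the comment mentions:

```python
assert 2**22 == 4 * 1024 * 1024  # 4 MiB, matching the "4 megabytes" comment
```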
@ -1029,9 +1010,7 @@ def scan_pessimistic_batches_for_oom(
f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..."
)
display_and_save_batch(
batch, params=params, graph_compiler=graph_compiler
)
display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
raise

View File

@ -202,8 +202,7 @@ def get_parser():
"--context-size",
type=int,
default=1,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
@ -263,9 +262,7 @@ def decode_one_batch(
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
@ -277,10 +274,7 @@ def decode_one_batch(
max_contexts=params.max_contexts,
max_states=params.max_states,
)
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
@ -401,9 +395,7 @@ def decode_dataset(
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
@ -429,9 +421,7 @@ def save_results(
# we compute CER for aishell dataset.
results_char = []
for res in results:
results_char.append(
(res[0], list("".join(res[1])), list("".join(res[2])))
)
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True
@ -442,8 +432,7 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tCER", file=f)
@ -488,9 +477,7 @@ def main():
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -518,9 +505,9 @@ def main():
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
@ -551,9 +538,9 @@ def main():
)
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"

View File

@ -132,8 +132,7 @@ def get_parser():
"--context-size",
type=int,
default=1,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
add_model_arguments(parser)
@ -166,9 +165,9 @@ def main():
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
@ -195,9 +194,9 @@ def main():
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
@ -252,9 +251,7 @@ def main():
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = (
params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
)
filename = params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
@ -262,17 +259,14 @@ def main():
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = (
params.exp_dir
/ f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
params.exp_dir / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
)
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -84,9 +84,7 @@ class Transducer(nn.Module):
self.decoder_datatang = decoder_datatang
self.joiner_datatang = joiner_datatang
self.simple_am_proj = ScaledLinear(
encoder_dim, vocab_size, initial_speed=0.5
)
self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5)
self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
if decoder_datatang is not None:
@ -179,9 +177,7 @@ class Transducer(nn.Module):
y_padded = y.pad(mode="constant", padding_value=0)
y_padded = y_padded.to(torch.int64)
boundary = torch.zeros(
(x.size(0), 4), dtype=torch.int64, device=x.device
)
boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
boundary[:, 2] = y_lens
boundary[:, 3] = encoder_out_lens
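For readers skimming the `boundary` hunk above: as far as I can tell (treat the column layout as an assumption, not something stated in this diff), each row of the `(N, 4)` boundary tensor used by the pruned RNN-T loss is `[begin_symbol, begin_frame, end_symbol, end_frame]`, so only the last two columns need to be filled in:

```python
import torch

# Hypothetical lengths for two utterances; the column meaning is my reading of
# the convention, mirrored from the assignments in the hunk above.
y_lens = torch.tensor([7, 5])              # target symbols per utterance
encoder_out_lens = torch.tensor([93, 60])  # encoder frames per utterance

boundary = torch.zeros((2, 4), dtype=torch.int64)
boundary[:, 2] = y_lens
boundary[:, 3] = encoder_out_lens
print(boundary)
# tensor([[ 0,  0,  7, 93],
#         [ 0,  0,  5, 60]])
```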

View File

@ -165,8 +165,7 @@ def get_parser():
"--context-size",
type=int,
default=1,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
@ -197,8 +196,7 @@ def read_sound_files(
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
@ -257,13 +255,9 @@ def main():
feature_lens = [f.size(0) for f in features]
feature_lens = torch.tensor(feature_lens, device=device)
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lens
)
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)
num_waves = encoder_out.size(0)
hyp_list = []
@ -311,9 +305,7 @@ def main():
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.method}"
)
raise ValueError(f"Unsupported decoding method: {params.method}")
hyp_list.append(hyp)
hyps = []
@ -330,9 +322,7 @@ def main():
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -96,9 +96,7 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
]
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
def add_model_arguments(parser: argparse.ArgumentParser):
@ -224,8 +222,7 @@ def get_parser():
"--initial-lr",
type=float,
default=0.003,
help="The initial learning rate. This value should not need "
"to be changed.",
help="The initial learning rate. This value should not need " "to be changed.",
)
parser.add_argument(
@ -248,8 +245,7 @@ def get_parser():
"--context-size",
type=int,
default=1,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
parser.add_argument(
@ -272,8 +268,7 @@ def get_parser():
"--am-scale",
type=float,
default=0.0,
help="The scale to smooth the loss with am (output of encoder network)"
"part.",
help="The scale to smooth the loss with am (output of encoder network)" "part.",
)
parser.add_argument(
@ -635,11 +630,7 @@ def compute_loss(
warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present.
"""
device = (
model.device
if isinstance(model, DDP)
else next(model.parameters()).device
)
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
@ -670,23 +661,16 @@ def compute_loss(
# overwhelming the simple_loss and causing it to diverge,
# in case it had not fully learned the alignment yet.
pruned_loss_scale = (
0.0
if warmup < 1.0
else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
)
loss = (
params.simple_loss_scale * simple_loss
+ pruned_loss_scale * pruned_loss
0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
)
loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
assert loss.requires_grad == is_training
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
@ -824,9 +808,7 @@ def train_one_epoch(
)
# summary stats
if datatang_train_dl is not None:
tot_loss = (
tot_loss * (1 - 1 / params.reset_interval)
) + loss_info
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
if aishell:
aishell_tot_loss = (
@ -847,9 +829,7 @@ def train_one_epoch(
scaler.update()
optimizer.zero_grad()
except: # noqa
display_and_save_batch(
batch, params=params, graph_compiler=graph_compiler
)
display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
raise
if params.print_diagnostics and batch_idx == 5:
@ -892,9 +872,7 @@ def train_one_epoch(
cur_lr = scheduler.get_last_lr()[0]
if datatang_train_dl is not None:
datatang_str = f"datatang_tot_loss[{datatang_tot_loss}], "
tot_loss_str = (
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
)
tot_loss_str = f"tot_loss[{tot_loss}], batch size: {batch_size}, "
else:
tot_loss_str = ""
datatang_str = ""
@ -1067,7 +1045,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2 ** 22
2**22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
@ -1076,9 +1054,7 @@ def run(rank, world_size, args):
train_cuts = filter_short_and_long_utterances(train_cuts)
if args.enable_musan:
cuts_musan = load_manifest(
Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
)
cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz")
else:
cuts_musan = None
@ -1093,9 +1069,7 @@ def run(rank, world_size, args):
if params.datatang_prob > 0:
datatang = AIDatatang200zh(manifest_dir=args.manifest_dir)
train_datatang_cuts = datatang.train_cuts()
train_datatang_cuts = filter_short_and_long_utterances(
train_datatang_cuts
)
train_datatang_cuts = filter_short_and_long_utterances(train_datatang_cuts)
train_datatang_cuts = train_datatang_cuts.repeat(times=None)
datatang_train_dl = asr_datamodule.train_dataloaders(
train_datatang_cuts,
@ -1249,9 +1223,7 @@ def scan_pessimistic_batches_for_oom(
f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..."
)
display_and_save_batch(
batch, params=params, graph_compiler=graph_compiler
)
display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
raise

View File

@ -183,17 +183,13 @@ class AishellAsrDataModule:
def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(
self.args.manifest_dir / "musan_cuts.jsonl.gz"
)
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
transforms.append(
CutMix(
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
)
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
@ -215,9 +211,7 @@ class AishellAsrDataModule:
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
)
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse versions, the default of num_frame_masks is
# different.
@ -260,9 +254,7 @@ class AishellAsrDataModule:
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
@ -308,9 +300,7 @@ class AishellAsrDataModule:
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
return_cuts=self.args.return_cuts,
)
else:
@ -366,13 +356,9 @@ class AishellAsrDataModule:
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to get dev cuts")
return load_manifest_lazy(
self.args.manifest_dir / "aishell_cuts_dev.jsonl.gz"
)
return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_dev.jsonl.gz")
@lru_cache()
def test_cuts(self) -> List[CutSet]:
logging.info("About to get test cuts")
return load_manifest_lazy(
self.args.manifest_dir / "aishell_cuts_test.jsonl.gz"
)
return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_test.jsonl.gz")

View File

@ -265,9 +265,7 @@ def decode_dataset(
if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
@ -289,9 +287,7 @@ def save_results(
# We compute CER for aishell dataset.
results_char = []
for res in results:
results_char.append(
(res[0], list("".join(res[1])), list("".join(res[2])))
)
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(f, f"{test_set_name}-{key}", results_char)
test_set_wers[key] = wer
@ -335,9 +331,7 @@ def main():
logging.info(f"device: {device}")
HLG = k2.Fsa.from_dict(
torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
)
HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
HLG = HLG.to(device)
assert HLG.requires_grad is False
@ -362,9 +356,7 @@ def main():
if params.export:
logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
torch.save(
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
)
torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
model.to(device)
model.eval()
@ -392,9 +384,7 @@ def main():
lexicon=lexicon,
)
save_results(
params=params, test_set_name=test_set, results_dict=results_dict
)
save_results(params=params, test_set_name=test_set, results_dict=results_dict)
logging.info("Done!")

View File

@ -66,10 +66,7 @@ class TdnnLstm(nn.Module):
nn.BatchNorm1d(num_features=500, affine=False),
)
self.lstms = nn.ModuleList(
[
nn.LSTM(input_size=500, hidden_size=500, num_layers=1)
for _ in range(5)
]
[nn.LSTM(input_size=500, hidden_size=500, num_layers=1) for _ in range(5)]
)
self.lstm_bnorms = nn.ModuleList(
[nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)]

View File

@ -53,9 +53,7 @@ def get_parser():
help="Path to words.txt",
)
parser.add_argument(
"--HLG", type=str, required=True, help="Path to HLG.pt."
)
parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
parser.add_argument(
"--method",
@ -113,8 +111,7 @@ def read_sound_files(
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
@ -173,9 +170,7 @@ def main():
logging.info("Decoding started")
features = fbank(waves)
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
features = features.permute(0, 2, 1) # now features is [N, C, T]
with torch.no_grad():
@ -219,9 +214,7 @@ def main():
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -49,12 +49,7 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
encode_supervisions,
setup_logger,
str2bool,
)
from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool
def get_parser():

View File

@ -47,9 +47,9 @@ def greedy_search(
device = model.device
decoder_input = torch.tensor(
[blank_id] * context_size, device=device
).reshape(1, context_size)
decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
1, context_size
)
decoder_out = model.decoder(decoder_input, need_pad=False)
@ -81,9 +81,9 @@ def greedy_search(
y = logits.argmax().item()
if y != blank_id:
hyp.append(y)
decoder_input = torch.tensor(
[hyp[-context_size:]], device=device
).reshape(1, context_size)
decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape(
1, context_size
)
decoder_out = model.decoder(decoder_input, need_pad=False)
@ -157,9 +157,7 @@ class HypothesisList(object):
"""
if length_norm:
return max(
self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)
)
return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
else:
return max(self._data.values(), key=lambda hyp: hyp.log_prob)
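The `get_most_probable` hunk above shows the length-normalized selection `hyp.log_prob / len(hyp.ys)`. A tiny self-contained illustration of why the normalization can change which hypothesis wins (the hypotheses below are made up):

```python
from collections import namedtuple

Hyp = namedtuple("Hyp", ["ys", "log_prob"])  # stand-in for the Hypothesis class

hyps = [
    Hyp(ys=[3, 42], log_prob=-4.0),             # short, higher total log-prob
    Hyp(ys=[3, 42, 17, 8, 29], log_prob=-7.5),  # longer, lower total log-prob
]

best_raw = max(hyps, key=lambda h: h.log_prob)
best_norm = max(hyps, key=lambda h: h.log_prob / len(h.ys))

print(best_raw.ys)   # [3, 42]            -> raw score favors the short hypothesis
print(best_norm.ys)  # [3, 42, 17, 8, 29] -> per-token score favors the longer one
```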
@ -246,9 +244,9 @@ def beam_search(
device = model.device
decoder_input = torch.tensor(
[blank_id] * context_size, device=device
).reshape(1, context_size)
decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
1, context_size
)
decoder_out = model.decoder(decoder_input, need_pad=False)

View File

@ -155,9 +155,7 @@ class ConformerEncoderLayer(nn.Module):
normalize_before: bool = True,
) -> None:
super(ConformerEncoderLayer, self).__init__()
self.self_attn = RelPositionMultiheadAttention(
d_model, nhead, dropout=0.0
)
self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
self.feed_forward = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
@ -175,18 +173,14 @@ class ConformerEncoderLayer(nn.Module):
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
self.norm_ff_macaron = nn.LayerNorm(
d_model
) # for the macaron style FNN module
self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module
self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
self.norm_mha = nn.LayerNorm(d_model) # for the MHA module
self.ff_scale = 0.5
self.norm_conv = nn.LayerNorm(d_model) # for the CNN module
self.norm_final = nn.LayerNorm(
d_model
) # for the final output of the block
self.norm_final = nn.LayerNorm(d_model) # for the final output of the block
self.dropout = nn.Dropout(dropout)
@ -220,9 +214,7 @@ class ConformerEncoderLayer(nn.Module):
residual = src
if self.normalize_before:
src = self.norm_ff_macaron(src)
src = residual + self.ff_scale * self.dropout(
self.feed_forward_macaron(src)
)
src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
if not self.normalize_before:
src = self.norm_ff_macaron(src)
@ -341,9 +333,7 @@ class RelPositionalEncoding(torch.nn.Module):
"""
def __init__(
self, d_model: int, dropout_rate: float, max_len: int = 5000
) -> None:
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
"""Construct an PositionalEncoding object."""
super(RelPositionalEncoding, self).__init__()
self.d_model = d_model
@ -359,9 +349,7 @@ class RelPositionalEncoding(torch.nn.Module):
# the length of self.pe is 2 * input_len - 1
if self.pe.size(1) >= x.size(1) * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
x.device
):
if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
# Suppose `i` means the position of the query vector and `j` means the
@ -631,9 +619,9 @@ class RelPositionMultiheadAttention(nn.Module):
if torch.equal(query, key) and torch.equal(key, value):
# self-attention
q, k, v = nn.functional.linear(
query, in_proj_weight, in_proj_bias
).chunk(3, dim=-1)
q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
3, dim=-1
)
elif torch.equal(key, value):
# encoder-decoder attention
@ -701,31 +689,22 @@ class RelPositionMultiheadAttention(nn.Module):
if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(0)
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
raise RuntimeError(
"The size of the 2D attn_mask is not correct."
)
raise RuntimeError("The size of the 2D attn_mask is not correct.")
elif attn_mask.dim() == 3:
if list(attn_mask.size()) != [
bsz * num_heads,
query.size(0),
key.size(0),
]:
raise RuntimeError(
"The size of the 3D attn_mask is not correct."
)
raise RuntimeError("The size of the 3D attn_mask is not correct.")
else:
raise RuntimeError(
"attn_mask's dimension {} is not supported".format(
attn_mask.dim()
)
"attn_mask's dimension {} is not supported".format(attn_mask.dim())
)
# attn_mask's dim is 3 now.
# convert ByteTensor key_padding_mask to bool
if (
key_padding_mask is not None
and key_padding_mask.dtype == torch.uint8
):
if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
warnings.warn(
"Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
)
@ -764,9 +743,7 @@ class RelPositionMultiheadAttention(nn.Module):
# first compute matrix a and matrix c
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
matrix_ac = torch.matmul(
q_with_bias_u, k
) # (batch, head, time1, time2)
matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2)
# compute matrix b and matrix d
matrix_bd = torch.matmul(
@ -778,9 +755,7 @@ class RelPositionMultiheadAttention(nn.Module):
matrix_ac + matrix_bd
) * scaling # (batch, head, time1, time2)
attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, -1
)
attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
assert list(attn_output_weights.size()) == [
bsz * num_heads,
@ -814,13 +789,9 @@ class RelPositionMultiheadAttention(nn.Module):
attn_output = torch.bmm(attn_output_weights, v)
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
attn_output = (
attn_output.transpose(0, 1)
.contiguous()
.view(tgt_len, bsz, embed_dim)
)
attn_output = nn.functional.linear(
attn_output, out_proj_weight, out_proj_bias
)
attn_output = (
attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
)
attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
if need_weights:
# average attention weights over heads
@ -843,9 +814,7 @@ class ConvolutionModule(nn.Module):
"""
def __init__(
self, channels: int, kernel_size: int, bias: bool = True
) -> None:
def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None:
"""Construct an ConvolutionModule object."""
super(ConvolutionModule, self).__init__()
# kernel_size should be an odd number for 'SAME' padding

View File

@ -99,8 +99,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
@ -227,9 +226,7 @@ def decode_one_batch(
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
hyps = []
batch_size = encoder_out.size(0)
@ -248,9 +245,7 @@ def decode_one_batch(
model=model, encoder_out=encoder_out_i, beam=params.beam_size
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
hyps.append([lexicon.token_table[i] for i in hyp])
if params.decoding_method == "greedy_search":
@ -319,9 +314,7 @@ def decode_dataset(
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
@ -346,9 +339,7 @@ def save_results(
# we compute CER for aishell dataset.
results_char = []
for res in results:
results_char.append(
(res[0], list("".join(res[1])), list("".join(res[2])))
)
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True
@ -359,8 +350,7 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tCER", file=f)
@ -430,9 +420,7 @@ def main():
if params.export:
logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
torch.save(
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
)
torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
return
model.to(device)

View File

@ -86,9 +86,7 @@ class Decoder(nn.Module):
if self.context_size > 1:
embedding_out = embedding_out.permute(0, 2, 1)
if need_pad is True:
embedding_out = F.pad(
embedding_out, pad=(self.context_size - 1, 0)
)
embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
else:
# During inference time, there is no need to do extra padding
# as we only need one output
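The `Decoder.forward` hunk above reflows the left-padding call. The point of `pad=(self.context_size - 1, 0)` is to keep the decoder's 1-D convolution causal over the previous `context_size - 1` tokens while preserving the sequence length; a hedged sketch with made-up sizes (a plain `Conv1d` stands in for the decoder's convolution):

```python
import torch
import torch.nn.functional as F

context_size = 2
# embedding_out after permute: (N, C, T) = (batch, embedding_dim, num_tokens)
embedding_out = torch.randn(1, 4, 6)

# Left-pad the time axis by context_size - 1 so a kernel of size context_size
# only sees the current and previous token(s), and T stays the same.
padded = F.pad(embedding_out, pad=(context_size - 1, 0))
conv = torch.nn.Conv1d(in_channels=4, out_channels=4, kernel_size=context_size)
out = conv(padded)

print(padded.shape, out.shape)  # torch.Size([1, 4, 7]) torch.Size([1, 4, 6])
```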

View File

@ -110,8 +110,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
return parser
@ -243,9 +242,7 @@ def main():
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -103,9 +103,7 @@ class Transducer(nn.Module):
y_padded = y.pad(mode="constant", padding_value=0)
y_padded = y_padded.to(torch.int64)
boundary = torch.zeros(
(x.size(0), 4), dtype=torch.int64, device=x.device
)
boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
boundary[:, 2] = y_lens
boundary[:, 3] = x_lens

View File

@ -117,8 +117,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
@ -212,8 +211,7 @@ def read_sound_files(
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
@ -273,9 +271,7 @@ def main():
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
feature_lengths = torch.tensor(feature_lengths, device=device)
@ -319,9 +315,7 @@ def main():
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -126,8 +126,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
parser.add_argument(
@ -389,9 +388,7 @@ def compute_loss(
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
@ -504,9 +501,7 @@ def train_one_epoch(
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
@ -625,9 +620,7 @@ def run(rank, world_size, args):
cur_lr = optimizer._rate
if tb_writer is not None:
tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
)
tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
if rank == 0:

View File

@ -250,9 +250,7 @@ def _get_activation_fn(activation: str):
elif activation == "gelu":
return nn.functional.gelu
raise RuntimeError(
"activation should be relu/gelu, not {}".format(activation)
)
raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
class PositionalEncoding(nn.Module):

View File

@ -29,10 +29,7 @@ from lhotse.dataset import (
K2SpeechRecognitionDataset,
SpecAugment,
)
from lhotse.dataset.input_strategies import (
OnTheFlyFeatures,
PrecomputedFeatures,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures, PrecomputedFeatures
from torch.utils.data import DataLoader
from icefall.utils import str2bool
@ -162,9 +159,7 @@ class AsrDataModule:
if cuts_musan is not None:
logging.info("Enable MUSAN")
transforms.append(
CutMix(
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
)
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
@ -173,9 +168,7 @@ class AsrDataModule:
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
)
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse versions, the default of num_frame_masks is
# different.
@ -252,9 +245,7 @@ class AsrDataModule:
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
return_cuts=self.args.return_cuts,
)
else:

View File

@ -170,8 +170,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
parser.add_argument(
@ -227,9 +226,7 @@ def decode_one_batch(
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
@ -241,10 +238,7 @@ def decode_one_batch(
max_contexts=params.max_contexts,
max_states=params.max_states,
)
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
@ -365,9 +359,7 @@ def decode_dataset(
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
@ -393,9 +385,7 @@ def save_results(
# we compute CER for aishell dataset.
results_char = []
for res in results:
results_char.append(
(res[0], list("".join(res[1])), list("".join(res[2])))
)
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True
@ -406,8 +396,7 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tCER", file=f)
@ -448,9 +437,7 @@ def main():
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

View File

@ -109,8 +109,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
return parser
@ -241,9 +240,7 @@ def main():
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -165,8 +165,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
@ -195,8 +194,7 @@ def read_sound_files(
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
@ -254,13 +252,9 @@ def main():
feature_lens = [f.size(0) for f in features]
feature_lens = torch.tensor(feature_lens, device=device)
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lens
)
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)
num_waves = encoder_out.size(0)
hyp_list = []
@ -308,9 +302,7 @@ def main():
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.method}"
)
raise ValueError(f"Unsupported decoding method: {params.method}")
hyp_list.append(hyp)
hyps = []
@ -327,9 +319,7 @@ def main():
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -149,8 +149,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
parser.add_argument(
@ -168,8 +167,7 @@ def get_parser():
"--datatang-prob",
type=float,
default=0.2,
help="The probability to select a batch from the "
"aidatatang_200zh dataset",
help="The probability to select a batch from the " "aidatatang_200zh dataset",
)
return parser
@ -449,9 +447,7 @@ def compute_loss(
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
@ -605,9 +601,7 @@ def train_one_epoch(
f"train/current_{prefix}_",
params.batch_idx_train,
)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
aishell_tot_loss.write_summary(
tb_writer, "train/aishell_tot_", params.batch_idx_train
)
@ -735,9 +729,7 @@ def run(rank, world_size, args):
train_datatang_cuts = train_datatang_cuts.repeat(times=None)
if args.enable_musan:
cuts_musan = load_manifest(
Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
)
cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz")
else:
cuts_musan = None
@ -776,9 +768,7 @@ def run(rank, world_size, args):
cur_lr = optimizer._rate
if tb_writer is not None:
tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
)
tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
if rank == 0:

View File

@ -171,8 +171,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
parser.add_argument(
@ -231,9 +230,7 @@ def decode_one_batch(
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
@ -245,10 +242,7 @@ def decode_one_batch(
max_contexts=params.max_contexts,
max_states=params.max_states,
)
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
@ -369,9 +363,7 @@ def decode_dataset(
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
@ -397,9 +389,7 @@ def save_results(
# we compute CER for aishell dataset.
results_char = []
for res in results:
results_char.append(
(res[0], list("".join(res[1])), list("".join(res[2])))
)
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True
@ -410,8 +400,7 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tCER", file=f)
@ -452,9 +441,7 @@ def main():
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

View File

@ -109,8 +109,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
return parser
@ -241,9 +240,7 @@ def main():
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -165,8 +165,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
@ -195,8 +194,7 @@ def read_sound_files(
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
@ -254,13 +252,9 @@ def main():
feature_lens = [f.size(0) for f in features]
feature_lens = torch.tensor(feature_lens, device=device)
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lens
)
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)
num_waves = encoder_out.size(0)
hyp_list = []
@ -308,9 +302,7 @@ def main():
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.method}"
)
raise ValueError(f"Unsupported decoding method: {params.method}")
hyp_list.append(hyp)
hyps = []
@ -327,9 +319,7 @@ def main():
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -142,8 +142,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
parser.add_argument(
@ -414,9 +413,7 @@ def compute_loss(
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
@ -529,9 +526,7 @@ def train_one_epoch(
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
@ -657,9 +652,7 @@ def run(rank, world_size, args):
cur_lr = optimizer._rate
if tb_writer is not None:
tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
)
tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
if rank == 0:

egs/aishell2/ASR/local/__init__.py Executable file → Normal file
View File

View File

@ -83,9 +83,7 @@ def compute_fbank_aishell2(num_mel_bins: int = 80):
)
if "train" in partition:
cut_set = (
cut_set
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
@ -111,9 +109,7 @@ def get_args():
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)

View File

View File

@ -216,13 +216,9 @@ class AiShell2AsrDataModule:
if self.args.enable_musan:
logging.info("Enable MUSAN")
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(
self.args.manifest_dir / "musan_cuts.jsonl.gz"
)
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms.append(
CutMix(
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
)
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
@ -244,9 +240,7 @@ class AiShell2AsrDataModule:
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
)
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
@ -290,9 +284,7 @@ class AiShell2AsrDataModule:
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
@ -348,9 +340,7 @@ class AiShell2AsrDataModule:
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
return_cuts=self.args.return_cuts,
)
else:
@ -406,9 +396,7 @@ class AiShell2AsrDataModule:
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to gen cuts from aishell2_cuts_dev.jsonl.gz")
return load_manifest_lazy(
self.args.manifest_dir / "aishell2_cuts_dev.jsonl.gz"
)
return load_manifest_lazy(self.args.manifest_dir / "aishell2_cuts_dev.jsonl.gz")
@lru_cache()
def test_cuts(self) -> CutSet:

View File

@ -269,8 +269,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
@ -348,9 +347,7 @@ def decode_one_batch(
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
hyps = []
if params.decoding_method == "fast_beam_search":
@ -409,10 +406,7 @@ def decode_one_batch(
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
@ -538,9 +532,7 @@ def decode_dataset(
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
@ -573,8 +565,7 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
@ -625,9 +616,7 @@ def main():
if "LG" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -661,9 +650,9 @@ def main():
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
@ -690,9 +679,9 @@ def main():
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
@ -749,9 +738,7 @@ def main():
)
decoding_graph.scores *= params.ngram_lm_scale
else:
decoding_graph = k2.trivial_graph(
params.vocab_size - 1, device=device
)
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None

View File

@ -133,8 +133,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
add_model_arguments(parser)
@ -167,9 +166,9 @@ def main():
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
@ -196,9 +195,9 @@ def main():
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
@ -266,9 +265,7 @@ def main():
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -159,8 +159,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
@ -192,8 +191,7 @@ def read_sound_files(
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
@ -254,15 +252,11 @@ def main():
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
feature_lengths = torch.tensor(feature_lengths, device=device)
encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lengths
)
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
num_waves = encoder_out.size(0)
hyps = []
@ -334,9 +328,7 @@ def main():
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -92,9 +92,7 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
]
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
def add_model_arguments(parser: argparse.ArgumentParser):
@ -220,8 +218,7 @@ def get_parser():
"--initial-lr",
type=float,
default=0.003,
help="The initial learning rate. This value should not need "
"to be changed.",
help="The initial learning rate. This value should not need " "to be changed.",
)
parser.add_argument(
@ -244,8 +241,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
parser.add_argument(
@ -268,8 +264,7 @@ def get_parser():
"--am-scale",
type=float,
default=0.0,
help="The scale to smooth the loss with am (output of encoder network)"
"part.",
help="The scale to smooth the loss with am (output of encoder network)" "part.",
)
parser.add_argument(
@ -603,11 +598,7 @@ def compute_loss(
warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present.
"""
device = (
model.device
if isinstance(model, DDP)
else next(model.parameters()).device
)
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
@ -636,23 +627,16 @@ def compute_loss(
# overwhelming the simple_loss and causing it to diverge,
# in case it had not fully learned the alignment yet.
pruned_loss_scale = (
0.0
if warmup < 1.0
else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
)
loss = (
params.simple_loss_scale * simple_loss
+ pruned_loss_scale * pruned_loss
0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
)
loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
assert loss.requires_grad == is_training
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
@ -771,9 +755,7 @@ def train_one_epoch(
scaler.update()
optimizer.zero_grad()
except: # noqa
display_and_save_batch(
batch, params=params, graph_compiler=graph_compiler
)
display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
raise
if params.print_diagnostics and batch_idx == 5:
@ -829,9 +811,7 @@ def train_one_epoch(
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
@ -939,7 +919,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2 ** 22
2**22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
@ -1104,9 +1084,7 @@ def scan_pessimistic_batches_for_oom(
f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..."
)
display_and_save_batch(
batch, params=params, graph_compiler=graph_compiler
)
display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
raise

View File

@ -85,9 +85,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80):
)
if "train" in partition:
cut_set = (
cut_set
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
@ -120,9 +118,7 @@ def get_args():
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)

View File

@ -86,9 +86,7 @@ def lexicon_to_fst_no_sil(
cur_state = loop_state
word = word2id[word]
pieces = [
token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
]
pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
for i in range(len(pieces) - 1):
w = word if i == 0 else eps
@ -142,9 +140,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
return False
def generate_lexicon(
token_sym_table: Dict[str, int], words: List[str]
) -> Lexicon:
def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
"""Generate a lexicon from a word list and token_sym_table.
Args:

View File

@ -317,9 +317,7 @@ def lexicon_to_fst(
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lang-dir", type=str, help="The lang dir, data/lang_phone"
)
parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
return parser.parse_args()

View File

@ -88,9 +88,7 @@ def test_read_lexicon(filename: str):
fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa.draw("L.pdf", title="L")
fsa_disambig = lexicon_to_fst(
lexicon_disambig, phone2id=phone2id, word2id=word2id
)
fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa_disambig.draw("L_disambig.pdf", title="L_disambig")

View File

@ -56,9 +56,7 @@ def get_parser():
parser.add_argument(
"--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
)
parser.add_argument(
"--space", default="<space>", type=str, help="space symbol"
)
parser.add_argument("--space", default="<space>", type=str, help="space symbol")
parser.add_argument(
"--non-lang-syms",
"-l",
@ -66,9 +64,7 @@ def get_parser():
type=str,
help="list of non-linguistic symobles, e.g., <NOISE> etc.",
)
parser.add_argument(
"text", type=str, default=False, nargs="?", help="input text"
)
parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
parser.add_argument(
"--trans_type",
"-t",
@ -108,8 +104,7 @@ def token2id(
if token_type == "lazy_pinyin":
text = lazy_pinyin(chars_list)
sub_ids = [
token_table[txt] if txt in token_table else oov_id
for txt in text
token_table[txt] if txt in token_table else oov_id for txt in text
]
ids.append(sub_ids)
else: # token_type = "pinyin"
@ -135,9 +130,7 @@ def main():
if args.text:
f = codecs.open(args.text, encoding="utf-8")
else:
f = codecs.getreader("utf-8")(
sys.stdin if is_python2 else sys.stdin.buffer
)
f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
sys.stdout = codecs.getwriter("utf-8")(
sys.stdout if is_python2 else sys.stdout.buffer

View File

@ -222,17 +222,13 @@ class Aishell4AsrDataModule:
The state dict for the training sampler.
"""
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(
self.args.manifest_dir / "musan_cuts.jsonl.gz"
)
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
transforms.append(
CutMix(
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
)
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
@ -254,9 +250,7 @@ class Aishell4AsrDataModule:
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
)
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
@ -300,9 +294,7 @@ class Aishell4AsrDataModule:
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
@ -359,9 +351,7 @@ class Aishell4AsrDataModule:
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
return_cuts=self.args.return_cuts,
)
else:

View File

@ -201,8 +201,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
@ -260,9 +259,7 @@ def decode_one_batch(
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
hyps = []
if params.decoding_method == "fast_beam_search":
@ -277,10 +274,7 @@ def decode_one_batch(
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
@ -401,9 +395,7 @@ def decode_dataset(
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
@ -436,8 +428,7 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
@ -480,9 +471,7 @@ def main():
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -510,9 +499,9 @@ def main():
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
@ -543,9 +532,9 @@ def main():
)
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"

View File

@ -136,8 +136,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
add_model_arguments(parser)
@ -169,9 +168,9 @@ def main():
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
@ -202,9 +201,9 @@ def main():
)
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
@ -276,9 +275,7 @@ def main():
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -172,8 +172,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
@ -205,8 +204,7 @@ def read_sound_files(
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
@ -266,15 +264,11 @@ def main():
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
feature_lengths = torch.tensor(feature_lengths, device=device)
encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lengths
)
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
num_waves = encoder_out.size(0)
hyps = []
@ -306,10 +300,7 @@ def main():
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
@ -350,9 +341,7 @@ def main():
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -85,9 +85,7 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
]
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
def add_model_arguments(parser: argparse.ArgumentParser):
@ -213,8 +211,7 @@ def get_parser():
"--initial-lr",
type=float,
default=0.003,
help="The initial learning rate. This value should not need "
"to be changed.",
help="The initial learning rate. This value should not need " "to be changed.",
)
parser.add_argument(
@ -237,8 +234,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
parser.add_argument(
@ -261,8 +257,7 @@ def get_parser():
"--am-scale",
type=float,
default=0.0,
help="The scale to smooth the loss with am (output of encoder network)"
"part.",
help="The scale to smooth the loss with am (output of encoder network)" "part.",
)
parser.add_argument(
@ -599,11 +594,7 @@ def compute_loss(
warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present.
"""
device = (
model.device
if isinstance(model, DDP)
else next(model.parameters()).device
)
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
@ -633,22 +624,15 @@ def compute_loss(
# overwhelming the simple_loss and causing it to diverge,
# in case it had not fully learned the alignment yet.
pruned_loss_scale = (
0.0
if warmup < 1.0
else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
)
loss = (
params.simple_loss_scale * simple_loss
+ pruned_loss_scale * pruned_loss
0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
)
loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
assert loss.requires_grad == is_training
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
@ -827,9 +811,7 @@ def train_one_epoch(
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
@ -937,7 +919,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2 ** 22
2**22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -84,9 +84,7 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80):
)
if "train" in partition:
cut_set = (
cut_set
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
cur_num_jobs = num_jobs if ex is None else 80
cur_num_jobs = min(cur_num_jobs, len(cut_set))
@ -121,9 +119,7 @@ def get_args():
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)

View File

@ -86,9 +86,7 @@ def lexicon_to_fst_no_sil(
cur_state = loop_state
word = word2id[word]
pieces = [
token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
]
pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
for i in range(len(pieces) - 1):
w = word if i == 0 else eps
@ -142,9 +140,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
return False
def generate_lexicon(
token_sym_table: Dict[str, int], words: List[str]
) -> Lexicon:
def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
"""Generate a lexicon from a word list and token_sym_table.
Args:

View File

@ -317,9 +317,7 @@ def lexicon_to_fst(
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lang-dir", type=str, help="The lang dir, data/lang_phone"
)
parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
return parser.parse_args()

View File

@ -88,9 +88,7 @@ def test_read_lexicon(filename: str):
fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa.draw("L.pdf", title="L")
fsa_disambig = lexicon_to_fst(
lexicon_disambig, phone2id=phone2id, word2id=word2id
)
fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa_disambig.draw("L_disambig.pdf", title="L_disambig")

View File

@ -30,8 +30,8 @@ with word segmenting:
import argparse
import paddle
import jieba
import paddle
from tqdm import tqdm
paddle.enable_static()

View File

@ -56,9 +56,7 @@ def get_parser():
parser.add_argument(
"--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
)
parser.add_argument(
"--space", default="<space>", type=str, help="space symbol"
)
parser.add_argument("--space", default="<space>", type=str, help="space symbol")
parser.add_argument(
"--non-lang-syms",
"-l",
@ -66,9 +64,7 @@ def get_parser():
type=str,
help="list of non-linguistic symobles, e.g., <NOISE> etc.",
)
parser.add_argument(
"text", type=str, default=False, nargs="?", help="input text"
)
parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
parser.add_argument(
"--trans_type",
"-t",
@ -108,8 +104,7 @@ def token2id(
if token_type == "lazy_pinyin":
text = lazy_pinyin(chars_list)
sub_ids = [
token_table[txt] if txt in token_table else oov_id
for txt in text
token_table[txt] if txt in token_table else oov_id for txt in text
]
ids.append(sub_ids)
else: # token_type = "pinyin"
@ -135,9 +130,7 @@ def main():
if args.text:
f = codecs.open(args.text, encoding="utf-8")
else:
f = codecs.getreader("utf-8")(
sys.stdin if is_python2 else sys.stdin.buffer
)
f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
sys.stdout = codecs.getwriter("utf-8")(
sys.stdout if is_python2 else sys.stdout.buffer

View File

@ -205,17 +205,13 @@ class AlimeetingAsrDataModule:
The state dict for the training sampler.
"""
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(
self.args.manifest_dir / "musan_cuts.jsonl.gz"
)
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
transforms.append(
CutMix(
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
)
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
@ -237,9 +233,7 @@ class AlimeetingAsrDataModule:
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
)
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
@ -282,9 +276,7 @@ class AlimeetingAsrDataModule:
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
@ -341,9 +333,7 @@ class AlimeetingAsrDataModule:
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
return_cuts=self.args.return_cuts,
)
else:

View File

@ -70,11 +70,7 @@ from beam_search import (
from lhotse.cut import Cut
from train import get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
@ -193,8 +189,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
@ -249,9 +244,7 @@ def decode_one_batch(
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
hyps = []
if params.decoding_method == "fast_beam_search":
@ -266,10 +259,7 @@ def decode_one_batch(
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
@ -390,9 +380,7 @@ def decode_dataset(
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
@ -425,8 +413,7 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
@ -563,8 +550,7 @@ def main():
)
dev_shards = [
str(path)
for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
]
cuts_dev_webdataset = CutSet.from_webdataset(
dev_shards,
@ -574,8 +560,7 @@ def main():
)
test_shards = [
str(path)
for path in sorted(glob.glob(os.path.join(test, "shared-*.tar")))
str(path) for path in sorted(glob.glob(os.path.join(test, "shared-*.tar")))
]
cuts_test_webdataset = CutSet.from_webdataset(
test_shards,
@ -588,9 +573,7 @@ def main():
return 1.0 <= c.duration
cuts_dev_webdataset = cuts_dev_webdataset.filter(remove_short_and_long_utt)
cuts_test_webdataset = cuts_test_webdataset.filter(
remove_short_and_long_utt
)
cuts_test_webdataset = cuts_test_webdataset.filter(remove_short_and_long_utt)
dev_dl = alimeeting.valid_dataloaders(cuts_dev_webdataset)
test_dl = alimeeting.test_dataloaders(cuts_test_webdataset)

View File

@ -103,8 +103,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
return parser
@ -173,9 +172,7 @@ def main():
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

Some files were not shown because too many files have changed in this diff.