apply black on all files

Desh Raj 2022-11-17 09:42:17 -05:00
parent b3920e5ab5
commit 107df3b115
437 changed files with 3861 additions and 7334 deletions
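Since the commit touches every Python file, a minimal sketch of how such a pass can be reproduced locally may help when reviewing; the black version is taken from the CI and pre-commit updates below, while the exact invocation is an assumption rather than part of this commit:

```bash
# Hypothetical reproduction of the formatting pass (assumed invocation).
pip install black==22.3.0
black --line-length=88 .    # rewrite all Python sources in place
git diff --stat             # compare against the 437-file changeset
```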

View File

@@ -45,17 +45,18 @@ jobs:
       - name: Install Python dependencies
         run: |
-          python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2 click==8.0.4
-          # See https://github.com/psf/black/issues/2964
-          # The version of click should be selected from 8.0.0, 8.0.1, 8.0.2, 8.0.3, and 8.0.4
+          python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0
+          # Click issue fixed in https://github.com/psf/black/pull/2966

       - name: Run flake8
         shell: bash
         working-directory: ${{github.workspace}}
         run: |
           # stop the build if there are Python syntax errors or undefined names
-          flake8 . --count --show-source --statistics
-          flake8 .
+          flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
+          # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
+          flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 \
+            --statistics --extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503

       - name: Run black
         shell: bash
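To mirror the updated lint job locally before pushing, a minimal sketch using the same pins and flags as the workflow above (the local setup itself is an assumption, not part of this diff):

```bash
# Same tool versions and flake8 flags as the CI job above.
python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0
# Hard failure on syntax errors and undefined names:
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# Advisory pass: --exit-zero downgrades the remaining findings to warnings.
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 \
  --statistics --extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503
```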

View File

@@ -1,26 +1,38 @@
 repos:
   - repo: https://github.com/psf/black
-    rev: 21.6b0
+    rev: 22.3.0
     hooks:
       - id: black
-        args: [--line-length=80]
-        additional_dependencies: ['click==8.0.1']
+        args: ["--line-length=88"]
+        additional_dependencies: ['click==8.1.0']
         exclude: icefall\/__init__\.py

   - repo: https://github.com/PyCQA/flake8
-    rev: 3.9.2
+    rev: 5.0.4
     hooks:
       - id: flake8
-        args: [--max-line-length=80]
+        args: ["--max-line-length=88", "--extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503"]
+        # What are we ignoring here?
+        # E203: whitespace before ':'
+        # E266: too many leading '#' for block comment
+        # E501: line too long
+        # F401: module imported but unused
+        # E402: module level import not at top of file
+        # F403: 'from module import *' used; unable to detect undefined names
+        # F841: local variable is assigned to but never used
+        # W503: line break before binary operator
+        # In addition, the default ignore list is:
+        # E121,E123,E126,E226,E24,E704,W503,W504

   - repo: https://github.com/pycqa/isort
-    rev: 5.9.2
+    rev: 5.10.1
     hooks:
       - id: isort
-        args: [--profile=black, --line-length=80]
+        args: ["--profile=black"]

   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.0.1
+    rev: v4.2.0
     hooks:
       - id: check-executables-have-shebangs
       - id: end-of-file-fixer
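With the hooks updated, contributors pick them up through the usual pre-commit workflow; a minimal sketch (standard pre-commit commands, not something this diff adds):

```bash
pip install pre-commit
pre-commit install             # register the hooks in .git/hooks
pre-commit run --all-files     # run black, flake8, isort, etc. over the whole tree
```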

View File

@@ -2,7 +2,7 @@
Two sets of configuration are provided: (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8, and (b) Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8.
If your NVIDIA driver supports CUDA Version: 11.3, please go for case (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8.
Otherwise, since the older PyTorch images are not updated with the [apt-key rotation by NVIDIA](https://developer.nvidia.com/blog/updating-the-cuda-linux-gpg-repository-key), you have to go for case (b) Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8. Ensure that your NVIDIA driver supports at least CUDA 11.0.
@@ -10,7 +10,7 @@ You can check the highest CUDA version within your NVIDIA driver's support with
```bash
$ nvidia-smi
Tue Sep 20 00:26:13 2022
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.119.03   Driver Version: 450.119.03   CUDA Version: 11.0     |
|-------------------------------+----------------------+----------------------+
@@ -26,7 +26,7 @@ Tue Sep 20 00:26:13 2022
| 41%   30C    P8    11W / 280W |      6MiB / 24220MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
@@ -40,15 +40,15 @@ Tue Sep 20 00:26:13 2022
```
## Building images locally
If your environment requires a proxy to access the Internet, remember to add that information directly into the Dockerfile.
In most cases, you can uncomment these lines in the Dockerfile and add your proxy details.
```dockerfile
ENV http_proxy=http://aaa.bb.cc.net:8080 \
    https_proxy=http://aaa.bb.cc.net:8080
```
Then, proceed with these commands.
### If you are case (a), i.e. your NVIDIA driver supports CUDA version >= 11.3:
@@ -72,11 +72,11 @@ docker run -it --runtime=nvidia --shm-size=2gb --name=icefall --gpus all icefall
```
### Tips:
1. Since your data and models most probably won't be inside the container, you must use the -v flag to access the host machine. Do this by specifying `-v {/path/in/host/machine}:{/path/in/docker}`.
2. Also, if your environment requires a proxy, this would be a good time to add it in too: `-e http_proxy=http://aaa.bb.cc.net:8080 -e https_proxy=http://aaa.bb.cc.net:8080`.
Overall, your docker run command should look like this.
```bash
docker run -it --runtime=nvidia --shm-size=2gb --name=icefall --gpus all -v {/path/in/host/machine}:{/path/in/docker} -e http_proxy=http://aaa.bb.cc.net:8080 -e https_proxy=http://aaa.bb.cc.net:8080 icefall/pytorch1.12.1
```
@@ -86,9 +86,9 @@ You can explore more docker run options [here](https://docs.docker.com/engine/re
### Linking to icefall in your host machine
If you already have icefall downloaded onto your host machine, you can use that repository instead so that changes in your code are visible inside and outside of the container.
Note: Remember to set the -v flag above during the first run of the container, as that is the only way for your container to access your host machine.
Warning: Check that the icefall in your host machine is visible from within your container before proceeding to the commands below.
Use these commands once you are inside the container.
@@ -103,7 +103,7 @@ ln -s {/path/in/docker/to/icefall} /workspace/icefall
docker exec -it icefall /bin/bash
```
## Restarting a killed container that has been run before.
```bash
docker start -ai icefall
```
@@ -111,4 +111,4 @@ docker start -ai icefall
## Sample usage of the CPU based images:
```bash
docker run -it icefall /bin/bash
```

View File

@@ -1,7 +1,7 @@
FROM pytorch/pytorch:1.12.1-cuda11.3-cudnn8-devel

# ENV http_proxy=http://aaa.bbb.cc.net:8080 \
#     https_proxy=http://aaa.bbb.cc.net:8080

# install normal source
RUN apt-get update && \
@@ -38,10 +38,10 @@ RUN wget -P /opt https://cmake.org/files/v3.18/cmake-3.18.0.tar.gz && \
    rm -rf cmake-3.18.0.tar.gz && \
    find /opt/cmake-3.18.0 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \
    cd -

# flac
RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz && \
    cd /opt && \
    xz -d flac-1.3.2.tar.xz && \
    tar -xvf flac-1.3.2.tar && \
    cd flac-1.3.2 && \
@@ -49,11 +49,11 @@ RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz &&
    make && make install && \
    rm -rf flac-1.3.2.tar && \
    find /opt/flac-1.3.2 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \
    cd -

RUN conda install -y -c pytorch torchaudio=0.12 && \
    pip install graphviz

#install k2 from source
RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \
@@ -68,7 +68,7 @@ RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
    cd /workspace/icefall && \
    pip install -r requirements.txt

RUN pip install kaldifeat
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH

WORKDIR /workspace/icefall

View File

@@ -1,12 +1,12 @@
FROM pytorch/pytorch:1.7.1-cuda11.0-cudnn8-devel

# ENV http_proxy=http://aaa.bbb.cc.net:8080 \
#     https_proxy=http://aaa.bbb.cc.net:8080

RUN rm /etc/apt/sources.list.d/cuda.list && \
    rm /etc/apt/sources.list.d/nvidia-ml.list && \
    apt-key del 7fa2af80

# install normal source
RUN apt-get update && \
    apt-get install -y --no-install-recommends \
@@ -36,7 +36,7 @@ RUN curl -fsSL https://developer.download.nvidia.com/compute/cuda/repos/ubuntu18
    curl -fsSL https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub | apt-key add - && \
    echo "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/cuda.list && \
    echo "deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/nvidia-ml.list && \
    rm -rf /var/lib/apt/lists/* && \
    mv /opt/conda/lib/libcufft.so.10 /opt/libcufft.so.10.bak && \
    mv /opt/conda/lib/libcurand.so.10 /opt/libcurand.so.10.bak && \
    mv /opt/conda/lib/libcublas.so.11 /opt/libcublas.so.11.bak && \
@@ -56,10 +56,10 @@ RUN wget -P /opt https://cmake.org/files/v3.18/cmake-3.18.0.tar.gz && \
    rm -rf cmake-3.18.0.tar.gz && \
    find /opt/cmake-3.18.0 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \
    cd -

# flac
RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz && \
    cd /opt && \
    xz -d flac-1.3.2.tar.xz && \
    tar -xvf flac-1.3.2.tar && \
    cd flac-1.3.2 && \
@@ -67,7 +67,7 @@ RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz &&
    make && make install && \
    rm -rf flac-1.3.2.tar && \
    find /opt/flac-1.3.2 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \
    cd -

RUN conda install -y -c pytorch torchaudio=0.7.1 && \
    pip install graphviz
@@ -79,7 +79,7 @@ RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \
    cd -

# install lhotse
RUN pip install git+https://github.com/lhotse-speech/lhotse

RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
    cd /workspace/icefall && \
@@ -88,4 +88,3 @@ RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH

WORKDIR /workspace/icefall

View File

@@ -1 +1 @@
(Single-line SVG badge "k2: >= v1.9", 1.1 KiB both before and after; the rendered badge is identical, so the change here is presumably whitespace-only.)

View File

@@ -1 +1 @@
(Single-line SVG badge "python: >= 3.6", 1.1 KiB both before and after; the rendered badge is identical, so the change here is presumably whitespace-only.)

View File

@@ -1 +1 @@
(Single-line SVG badge "torch: >= 1.6.0", 1.1 KiB both before and after; the rendered badge is identical, so the change here is presumably whitespace-only.)

View File

@@ -19,4 +19,3 @@ It can be downloaded from `<https://www.openslr.org/33/>`_
   tdnn_lstm_ctc
   conformer_ctc
   stateless_transducer

View File

@@ -6,4 +6,3 @@ TIMIT
   tdnn_ligru_ctc
   tdnn_lstm_ctc

View File

@@ -148,10 +148,10 @@ Some commonly used options are:
    $ ./tdnn_ligru_ctc/decode.py --epoch 25 --avg 17

  uses the average of ``epoch-9.pt``, ``epoch-10.pt``, ``epoch-11.pt``,
  ``epoch-12.pt``, ``epoch-13.pt``, ``epoch-14.pt``, ``epoch-15.pt``,
  ``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``, ``epoch-19.pt``,
  ``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``, ``epoch-23.pt``,
  ``epoch-24.pt`` and ``epoch-25.pt``
  for decoding.
@@ -317,13 +317,13 @@ To decode with ``1best`` method, we can use:
.. code-block:: bash

  ./tdnn_ligru_ctc/pretrained.py
    --method 1best
    --checkpoint ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/exp/pretrained_average_9_25.pt
    --words-file ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/words.txt
    --HLG ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/HLG.pt
    ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV
    ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV
    ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV

The output is:
@@ -337,7 +337,7 @@ The output is:
  2021-11-08 20:41:38,697 INFO [pretrained.py:210] Reading sound files: ['./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV']
  2021-11-08 20:41:38,704 INFO [pretrained.py:216] Decoding started
  2021-11-08 20:41:39,819 INFO [pretrained.py:246] Use HLG decoding
  2021-11-08 20:41:39,829 INFO [pretrained.py:267]
  ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV:
  sil dh ih sh uw ah l iy v iy z ih sil p r aa sil k s ih m ey dx ih sil d w uh dx ih w ih s f iy l ih ng w ih th ih n ih m s eh l f sil jh
@@ -362,8 +362,8 @@ To decode with ``whole-lattice-rescoring`` method, you can use
    --HLG ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/HLG.pt \
    --G ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lm/G_4_gram.pt \
    --ngram-lm-scale 0.1 \
    ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV
    ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV
    ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV

The decoding output is:
@@ -378,7 +378,7 @@ The decoding output is:
  2021-11-08 20:37:54,715 INFO [pretrained.py:210] Reading sound files: ['./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV']
  2021-11-08 20:37:54,720 INFO [pretrained.py:216] Decoding started
  2021-11-08 20:37:55,808 INFO [pretrained.py:251] Use HLG decoding + LM rescoring
  2021-11-08 20:37:56,348 INFO [pretrained.py:267]
  ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV:
  sil dh ih sh uw ah l iy v iy z ah sil p r aa sil k s ih m ey dx ih sil d w uh dx iy w ih s f iy l iy ng w ih th ih n ih m s eh l f sil jh

View File

@@ -148,8 +148,8 @@ Some commonly used options are:
    $ ./tdnn_lstm_ctc/decode.py --epoch 25 --avg 10

  uses the average of ``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``,
  ``epoch-19.pt``, ``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``,
  ``epoch-23.pt``, ``epoch-24.pt`` and ``epoch-25.pt``
  for decoding.
@@ -315,13 +315,13 @@ To decode with ``1best`` method, we can use:
.. code-block:: bash

  ./tdnn_lstm_ctc/pretrained.py
    --method 1best
    --checkpoint ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/exp/pretrained_average_16_25.pt
    --words-file ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/words.txt
    --HLG ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/HLG.pt
    ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV
    ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV
    ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV

The output is:
@@ -335,7 +335,7 @@ The output is:
  2021-11-08 21:02:53,827 INFO [pretrained.py:210] Reading sound files: ['./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV']
  2021-11-08 21:02:53,831 INFO [pretrained.py:216] Decoding started
  2021-11-08 21:02:54,380 INFO [pretrained.py:246] Use HLG decoding
  2021-11-08 21:02:54,387 INFO [pretrained.py:267]
  ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV:
  sil dh ih sh uw ah l iy v iy z ih sil p r aa sil k s ih m ey dx ih sil d w uh dx iy w ih s f iy l iy w ih th ih n ih m s eh l f sil jh
@@ -360,8 +360,8 @@ To decode with ``whole-lattice-rescoring`` method, you can use
    --HLG ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/HLG.pt \
    --G ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lm/G_4_gram.pt \
    --ngram-lm-scale 0.08 \
    ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV
    ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV
    ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV

The decoding output is:
@@ -376,7 +376,7 @@ The decoding output is:
  2021-11-08 20:05:26,978 INFO [pretrained.py:210] Reading sound files: ['./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV']
  2021-11-08 20:05:26,981 INFO [pretrained.py:216] Decoding started
  2021-11-08 20:05:27,519 INFO [pretrained.py:251] Use HLG decoding + LM rescoring
  2021-11-08 20:05:27,878 INFO [pretrained.py:267]
  ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV:
  sil dh ih sh uw l iy v iy z ih sil p r aa sil k s ah m ey dx ih sil w uh dx iy w ih s f iy l ih ng w ih th ih n ih m s eh l f sil jh

View File

@@ -87,9 +87,7 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
         )
         if "train" in partition:
             cut_set = (
-                cut_set
-                + cut_set.perturb_speed(0.9)
-                + cut_set.perturb_speed(1.1)
+                cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
             )
         cut_set = cut_set.compute_and_store_features(
             extractor=extractor,
@@ -116,9 +114,7 @@ def get_args():

 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

     logging.basicConfig(format=formatter, level=logging.INFO)

View File

@@ -86,9 +86,7 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state

         word = word2id[word]
-        pieces = [
-            token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
-        ]
+        pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]

         for i in range(len(pieces) - 1):
             w = word if i == 0 else eps
@@ -142,9 +140,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
     return False


-def generate_lexicon(
-    token_sym_table: Dict[str, int], words: List[str]
-) -> Lexicon:
+def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
     """Generate a lexicon from a word list and token_sym_table.

     Args:

View File

@@ -317,9 +317,7 @@ def lexicon_to_fst(

 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--lang-dir", type=str, help="The lang dir, data/lang_phone"
-    )
+    parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
     return parser.parse_args()

View File

@@ -88,9 +88,7 @@ def test_read_lexicon(filename: str):
     fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa.draw("L.pdf", title="L")

-    fsa_disambig = lexicon_to_fst(
-        lexicon_disambig, phone2id=phone2id, word2id=word2id
-    )
+    fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
     fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
     fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa_disambig.draw("L_disambig.pdf", title="L_disambig")

View File

@@ -56,9 +56,7 @@ def get_parser():
     parser.add_argument(
         "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
     )
-    parser.add_argument(
-        "--space", default="<space>", type=str, help="space symbol"
-    )
+    parser.add_argument("--space", default="<space>", type=str, help="space symbol")
     parser.add_argument(
         "--non-lang-syms",
         "-l",
@@ -66,9 +64,7 @@ def get_parser():
         type=str,
         help="list of non-linguistic symobles, e.g., <NOISE> etc.",
     )
-    parser.add_argument(
-        "text", type=str, default=False, nargs="?", help="input text"
-    )
+    parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
     parser.add_argument(
         "--trans_type",
         "-t",
@@ -108,8 +104,7 @@ def token2id(
             if token_type == "lazy_pinyin":
                 text = lazy_pinyin(chars_list)
                 sub_ids = [
-                    token_table[txt] if txt in token_table else oov_id
-                    for txt in text
+                    token_table[txt] if txt in token_table else oov_id for txt in text
                 ]
                 ids.append(sub_ids)
             else:  # token_type = "pinyin"
@@ -135,9 +130,7 @@ def main():
     if args.text:
         f = codecs.open(args.text, encoding="utf-8")
     else:
-        f = codecs.getreader("utf-8")(
-            sys.stdin if is_python2 else sys.stdin.buffer
-        )
+        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
     sys.stdout = codecs.getwriter("utf-8")(
         sys.stdout if is_python2 else sys.stdout.buffer

View File

@@ -106,11 +106,10 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
  if [ ! -f $lang_char_dir/words.txt ]; then
    ./local/prepare_words.py \
      --input-file $lang_char_dir/words_no_ids.txt \
      --output-file $lang_char_dir/words.txt
  fi

  if [ ! -f $lang_char_dir/L_disambig.pt ]; then
    ./local/prepare_char.py
  fi
fi

View File

@@ -205,17 +205,13 @@ class Aidatatang_200zhAsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
-            self.args.manifest_dir / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")

         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
@@ -237,9 +233,7 @@ class Aidatatang_200zhAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -282,9 +276,7 @@ class Aidatatang_200zhAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -340,9 +332,7 @@ class Aidatatang_200zhAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:

View File

@@ -69,11 +69,7 @@ from beam_search import (
 )
 from train import get_params, get_transducer_model

-from icefall.checkpoint import (
-    average_checkpoints,
-    find_checkpoints,
-    load_checkpoint,
-)
+from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -192,8 +188,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -249,9 +244,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)

-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []

     if params.decoding_method == "fast_beam_search":
@@ -266,10 +259,7 @@ def decode_one_batch(
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -390,9 +380,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"

-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
@@ -425,8 +413,7 @@ def save_results(
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)

View File

@@ -103,8 +103,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     return parser
@@ -173,9 +172,7 @@ def main():

 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

     logging.basicConfig(format=formatter, level=logging.INFO)
     main()

View File

@@ -162,8 +162,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )

     parser.add_argument(
@@ -194,8 +193,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -257,9 +255,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]

-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))

     feature_lengths = torch.tensor(feature_lengths, device=device)
@@ -284,10 +280,7 @@ def main():
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -339,9 +332,7 @@ def main():

 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

     logging.basicConfig(format=formatter, level=logging.INFO)
     main()

View File

@@ -81,9 +81,7 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool

-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

 os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
@@ -187,8 +185,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )

     parser.add_argument(
@@ -211,8 +208,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)" "part.",
     )

     parser.add_argument(
@@ -542,22 +538,15 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-        )
-        loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
         )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss

     assert loss.requires_grad == is_training

     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()

     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -711,9 +700,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)

         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -813,7 +800,7 @@ def run(rank, world_size, args):
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -157,9 +157,7 @@ class ConformerEncoderLayer(nn.Module):
normalize_before: bool = True, normalize_before: bool = True,
) -> None: ) -> None:
super(ConformerEncoderLayer, self).__init__() super(ConformerEncoderLayer, self).__init__()
self.self_attn = RelPositionMultiheadAttention( self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
d_model, nhead, dropout=0.0
)
self.feed_forward = nn.Sequential( self.feed_forward = nn.Sequential(
nn.Linear(d_model, dim_feedforward), nn.Linear(d_model, dim_feedforward),
@ -177,18 +175,14 @@ class ConformerEncoderLayer(nn.Module):
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
self.norm_ff_macaron = nn.LayerNorm( self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module
d_model
) # for the macaron style FNN module
self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module
self.ff_scale = 0.5 self.ff_scale = 0.5
self.norm_conv = nn.LayerNorm(d_model) # for the CNN module self.norm_conv = nn.LayerNorm(d_model) # for the CNN module
self.norm_final = nn.LayerNorm( self.norm_final = nn.LayerNorm(d_model) # for the final output of the block
d_model
) # for the final output of the block
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
@ -222,9 +216,7 @@ class ConformerEncoderLayer(nn.Module):
residual = src residual = src
if self.normalize_before: if self.normalize_before:
src = self.norm_ff_macaron(src) src = self.norm_ff_macaron(src)
src = residual + self.ff_scale * self.dropout( src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
self.feed_forward_macaron(src)
)
if not self.normalize_before: if not self.normalize_before:
src = self.norm_ff_macaron(src) src = self.norm_ff_macaron(src)
@ -343,9 +335,7 @@ class RelPositionalEncoding(torch.nn.Module):
""" """
-    def __init__(
-        self, d_model: int, dropout_rate: float, max_len: int = 5000
-    ) -> None:
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model

@@ -361,9 +351,7 @@ class RelPositionalEncoding(torch.nn.Module):
         # the length of self.pe is 2 * input_len - 1
         if self.pe.size(1) >= x.size(1) * 2 - 1:
             # Note: TorchScript doesn't implement operator== for torch.Device
-            if self.pe.dtype != x.dtype or str(self.pe.device) != str(
-                x.device
-            ):
+            if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
                 self.pe = self.pe.to(dtype=x.dtype, device=x.device)
             return
         # Suppose `i` means to the position of query vector and `j` means the

@@ -633,9 +621,9 @@ class RelPositionMultiheadAttention(nn.Module):
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(
-                query, in_proj_weight, in_proj_bias
-            ).chunk(3, dim=-1)
+            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
+                3, dim=-1
+            )

         elif torch.equal(key, value):
             # encoder-decoder attention

@@ -703,31 +691,22 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
                 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                    raise RuntimeError(
-                        "The size of the 2D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
                     query.size(0),
                     key.size(0),
                 ]:
-                    raise RuntimeError(
-                        "The size of the 3D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
             else:
                 raise RuntimeError(
-                    "attn_mask's dimension {} is not supported".format(
-                        attn_mask.dim()
-                    )
+                    "attn_mask's dimension {} is not supported".format(attn_mask.dim())
                 )
             # attn_mask's dim is 3 now.

         # convert ByteTensor key_padding_mask to bool
-        if (
-            key_padding_mask is not None
-            and key_padding_mask.dtype == torch.uint8
-        ):
+        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
             warnings.warn(
                 "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
             )

@@ -766,9 +745,7 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(
-            q_with_bias_u, k
-        )  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)

         # compute matrix b and matrix d
         matrix_bd = torch.matmul(

@@ -780,9 +757,7 @@ class RelPositionMultiheadAttention(nn.Module):
             matrix_ac + matrix_bd
         ) * scaling  # (batch, head, time1, time2)

-        attn_output_weights = attn_output_weights.view(
-            bsz * num_heads, tgt_len, -1
-        )
+        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)

         assert list(attn_output_weights.size()) == [
             bsz * num_heads,

@@ -816,13 +791,9 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
         attn_output = (
-            attn_output.transpose(0, 1)
-            .contiguous()
-            .view(tgt_len, bsz, embed_dim)
-        )
-        attn_output = nn.functional.linear(
-            attn_output, out_proj_weight, out_proj_bias
+            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
         )
+        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)

         if need_weights:
             # average attention weights over heads

@@ -845,9 +816,7 @@ class ConvolutionModule(nn.Module):
     """

-    def __init__(
-        self, channels: int, kernel_size: int, bias: bool = True
-    ) -> None:
+    def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None:
         """Construct an ConvolutionModule object."""
         super(ConvolutionModule, self).__init__()
         # kernerl_size should be a odd number for 'SAME' padding
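
The matrix_ac hunk above only re-wraps a call, but the shape bookkeeping it relies on is easy to lose track of. A minimal, self-contained sketch of just that step, with illustrative dimensions (not the recipe's actual sizes):

    import torch

    batch, head, time1, time2, d_k = 2, 4, 5, 5, 8
    q_with_bias_u = torch.randn(batch, head, time1, d_k)
    k = torch.randn(time2, batch, head, d_k)

    k = k.permute(1, 2, 3, 0)                   # -> (batch, head, d_k, time2)
    matrix_ac = torch.matmul(q_with_bias_u, k)  # -> (batch, head, time1, time2)
    assert matrix_ac.shape == (batch, head, time1, time2)

The matrix_bd term is computed the same way against the relative positional embeddings; see Section 3.3 of the Transformer-XL paper cited in the comment.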


@@ -401,9 +401,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"

-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results

@@ -431,9 +429,7 @@ def save_results(
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
-            results_char.append(
-                (res[0], list("".join(res[1])), list("".join(res[2])))
-            )
+            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results_char, enable_log=enable_log

@@ -441,9 +437,7 @@ def save_results(
             test_set_wers[key] = wer

     if enable_log:
-        logging.info(
-            "Wrote detailed error stats to {}".format(errs_filename)
-        )
+        logging.info("Wrote detailed error stats to {}".format(errs_filename))

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt"

@@ -562,9 +556,7 @@ def main():
             eos_id=eos_id,
         )

-        save_results(
-            params=params, test_set_name=test_set, results_dict=results_dict
-        )
+        save_results(params=params, test_set_name=test_set, results_dict=results_dict)

     logging.info("Done!")


@@ -157,9 +157,7 @@ def main():
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

     logging.basicConfig(format=formatter, level=logging.INFO)
     main()


@@ -211,8 +211,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])

@@ -274,9 +273,7 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))

     # Note: We don't use key padding mask for attention during decoding
     with torch.no_grad():

@@ -371,9 +368,7 @@ def main():
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

     logging.basicConfig(format=formatter, level=logging.INFO)
     main()


@@ -42,13 +42,9 @@ class Conv2dSubsampling(nn.Module):
         assert idim >= 7
         super().__init__()
         self.conv = nn.Sequential(
-            nn.Conv2d(
-                in_channels=1, out_channels=odim, kernel_size=3, stride=2
-            ),
+            nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2),
             nn.ReLU(),
-            nn.Conv2d(
-                in_channels=odim, out_channels=odim, kernel_size=3, stride=2
-            ),
+            nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2),
             nn.ReLU(),
         )
         self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)

@@ -132,17 +128,13 @@ class VggSubsampling(nn.Module):
                 )
             )
             layers.append(
-                torch.nn.MaxPool2d(
-                    kernel_size=2, stride=2, padding=0, ceil_mode=True
-                )
+                torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True)
             )
             cur_channels = block_dim

         self.layers = nn.Sequential(*layers)

-        self.out = nn.Linear(
-            block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
-        )
+        self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Subsample x.


@@ -16,9 +16,8 @@
 # limitations under the License.

-from subsampling import Conv2dSubsampling
-from subsampling import VggSubsampling
 import torch
+from subsampling import Conv2dSubsampling, VggSubsampling


 def test_conv2d_subsampling():


@@ -382,9 +382,7 @@ def compute_loss(
         #
         # See https://github.com/k2-fsa/icefall/issues/97
         # for more details
-        unsorted_token_ids = graph_compiler.texts_to_ids(
-            supervisions["text"]
-        )
+        unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"])
         att_loss = mmodel.decoder_forward(
             encoder_memory,
             memory_mask,

@@ -520,9 +518,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)

         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")

@@ -630,9 +626,7 @@ def run(rank, world_size, args):
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

         if rank == 0:


@@ -149,9 +149,7 @@ class Transformer(nn.Module):
                 norm=decoder_norm,
             )

-            self.decoder_output_layer = torch.nn.Linear(
-                d_model, self.decoder_num_class
-            )
+            self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class)

             self.decoder_criterion = LabelSmoothingLoss()
         else:

@@ -183,9 +181,7 @@ class Transformer(nn.Module):
         x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
         x = self.feat_batchnorm(x)
         x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
-        encoder_memory, memory_key_padding_mask = self.run_encoder(
-            x, supervision
-        )
+        encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision)
         x = self.ctc_output(encoder_memory)
         return x, encoder_memory, memory_key_padding_mask

@@ -266,23 +262,17 @@ class Transformer(nn.Module):
         """
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(
-            ys_in, batch_first=True, padding_value=float(eos_id)
-        )
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))

         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(
-            ys_out, batch_first=True, padding_value=float(-1)
-        )
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))

         device = memory.device
         ys_in_pad = ys_in_pad.to(device)
         ys_out_pad = ys_out_pad.to(device)

-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
-            device
-        )
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)

         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask

@@ -343,23 +333,17 @@ class Transformer(nn.Module):
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(
-            ys_in, batch_first=True, padding_value=float(eos_id)
-        )
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))

         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(
-            ys_out, batch_first=True, padding_value=float(-1)
-        )
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))

         device = memory.device
         ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
         ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)

-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
-            device
-        )
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)

         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask

@@ -632,9 +616,7 @@ def _get_activation_fn(activation: str):
     elif activation == "gelu":
         return nn.functional.gelu

-    raise RuntimeError(
-        "activation should be relu/gelu, not {}".format(activation)
-    )
+    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))


 class PositionalEncoding(nn.Module):

@@ -836,9 +818,7 @@ def encoder_padding_mask(
         1,
     ).to(torch.int32)

-    lengths = [
-        0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)
-    ]
+    lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)]
     for idx in range(supervision_segments.size(0)):
         # Note: TorchScript doesn't allow to unpack tensors as tuples
         sequence_idx = supervision_segments[idx, 0].item()

@@ -859,9 +839,7 @@ def encoder_padding_mask(
     return mask


-def decoder_padding_mask(
-    ys_pad: torch.Tensor, ignore_id: int = -1
-) -> torch.Tensor:
+def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor:
     """Generate a length mask for input.

     The masked position are filled with True,
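
decoder_padding_mask only changed shape in this commit, but its contract is worth spelling out: True marks padded positions. A minimal sketch consistent with the docstring (not the recipe's exact implementation):

    import torch

    def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor:
        # Positions equal to ignore_id are padding and become True.
        return ys_pad == ignore_id

    ys = torch.tensor([[5, 7, 9], [5, -1, -1]])
    assert decoder_padding_mask(ys).tolist() == [
        [False, False, False],
        [False, True, True],
    ]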


@@ -157,9 +157,7 @@ class ConformerEncoderLayer(nn.Module):
         normalize_before: bool = True,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
-        self.self_attn = RelPositionMultiheadAttention(
-            d_model, nhead, dropout=0.0
-        )
+        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)

         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),

@@ -177,18 +175,14 @@ class ConformerEncoderLayer(nn.Module):
         self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)

-        self.norm_ff_macaron = nn.LayerNorm(
-            d_model
-        )  # for the macaron style FNN module
+        self.norm_ff_macaron = nn.LayerNorm(d_model)  # for the macaron style FNN module
         self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
         self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module

         self.ff_scale = 0.5

         self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
-        self.norm_final = nn.LayerNorm(
-            d_model
-        )  # for the final output of the block
+        self.norm_final = nn.LayerNorm(d_model)  # for the final output of the block

         self.dropout = nn.Dropout(dropout)

@@ -222,9 +216,7 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(
-            self.feed_forward_macaron(src)
-        )
+        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)

@@ -343,9 +335,7 @@ class RelPositionalEncoding(torch.nn.Module):
     """

-    def __init__(
-        self, d_model: int, dropout_rate: float, max_len: int = 5000
-    ) -> None:
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model

@@ -361,9 +351,7 @@ class RelPositionalEncoding(torch.nn.Module):
         # the length of self.pe is 2 * input_len - 1
         if self.pe.size(1) >= x.size(1) * 2 - 1:
             # Note: TorchScript doesn't implement operator== for torch.Device
-            if self.pe.dtype != x.dtype or str(self.pe.device) != str(
-                x.device
-            ):
+            if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
                 self.pe = self.pe.to(dtype=x.dtype, device=x.device)
             return
         # Suppose `i` means to the position of query vector and `j` means the

@@ -633,9 +621,9 @@ class RelPositionMultiheadAttention(nn.Module):
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(
-                query, in_proj_weight, in_proj_bias
-            ).chunk(3, dim=-1)
+            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
+                3, dim=-1
+            )

         elif torch.equal(key, value):
             # encoder-decoder attention

@@ -703,31 +691,22 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
                 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                    raise RuntimeError(
-                        "The size of the 2D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
                     query.size(0),
                     key.size(0),
                 ]:
-                    raise RuntimeError(
-                        "The size of the 3D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
             else:
                 raise RuntimeError(
-                    "attn_mask's dimension {} is not supported".format(
-                        attn_mask.dim()
-                    )
+                    "attn_mask's dimension {} is not supported".format(attn_mask.dim())
                 )
             # attn_mask's dim is 3 now.

         # convert ByteTensor key_padding_mask to bool
-        if (
-            key_padding_mask is not None
-            and key_padding_mask.dtype == torch.uint8
-        ):
+        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
             warnings.warn(
                 "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
             )

@@ -766,9 +745,7 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(
-            q_with_bias_u, k
-        )  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)

         # compute matrix b and matrix d
         matrix_bd = torch.matmul(

@@ -780,9 +757,7 @@ class RelPositionMultiheadAttention(nn.Module):
             matrix_ac + matrix_bd
         ) * scaling  # (batch, head, time1, time2)

-        attn_output_weights = attn_output_weights.view(
-            bsz * num_heads, tgt_len, -1
-        )
+        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)

         assert list(attn_output_weights.size()) == [
             bsz * num_heads,

@@ -816,13 +791,9 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
         attn_output = (
-            attn_output.transpose(0, 1)
-            .contiguous()
-            .view(tgt_len, bsz, embed_dim)
-        )
-        attn_output = nn.functional.linear(
-            attn_output, out_proj_weight, out_proj_bias
+            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
         )
+        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)

         if need_weights:
             # average attention weights over heads

@@ -845,9 +816,7 @@ class ConvolutionModule(nn.Module):
     """

-    def __init__(
-        self, channels: int, kernel_size: int, bias: bool = True
-    ) -> None:
+    def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None:
         """Construct an ConvolutionModule object."""
         super(ConvolutionModule, self).__init__()
         # kernerl_size should be a odd number for 'SAME' padding


@@ -413,9 +413,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"

-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results

@@ -443,9 +441,7 @@ def save_results(
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
-            results_char.append(
-                (res[0], list("".join(res[1])), list("".join(res[2])))
-            )
+            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results_char, enable_log=enable_log

@@ -453,9 +449,7 @@ def save_results(
             test_set_wers[key] = wer

     if enable_log:
-        logging.info(
-            "Wrote detailed error stats to {}".format(errs_filename)
-        )
+        logging.info("Wrote detailed error stats to {}".format(errs_filename))

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt"

@@ -550,9 +544,7 @@ def main():
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
+        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
         return

     model.to(device)

@@ -581,9 +573,7 @@ def main():
             eos_id=eos_id,
         )

-        save_results(
-            params=params, test_set_name=test_set, results_dict=results_dict
-        )
+        save_results(params=params, test_set_name=test_set, results_dict=results_dict)

     logging.info("Done!")


@@ -42,13 +42,9 @@ class Conv2dSubsampling(nn.Module):
         assert idim >= 7
         super().__init__()
         self.conv = nn.Sequential(
-            nn.Conv2d(
-                in_channels=1, out_channels=odim, kernel_size=3, stride=2
-            ),
+            nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2),
             nn.ReLU(),
-            nn.Conv2d(
-                in_channels=odim, out_channels=odim, kernel_size=3, stride=2
-            ),
+            nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2),
             nn.ReLU(),
         )
         self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)

@@ -132,17 +128,13 @@ class VggSubsampling(nn.Module):
                 )
             )
             layers.append(
-                torch.nn.MaxPool2d(
-                    kernel_size=2, stride=2, padding=0, ceil_mode=True
-                )
+                torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True)
             )
             cur_channels = block_dim

         self.layers = nn.Sequential(*layers)

-        self.out = nn.Linear(
-            block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
-        )
+        self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Subsample x.


@@ -511,9 +511,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)

         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")

@@ -625,9 +623,7 @@ def run(rank, world_size, args):
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

         if rank == 0:


@@ -149,9 +149,7 @@ class Transformer(nn.Module):
                 norm=decoder_norm,
             )

-            self.decoder_output_layer = torch.nn.Linear(
-                d_model, self.decoder_num_class
-            )
+            self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class)

             self.decoder_criterion = LabelSmoothingLoss()
         else:

@@ -183,9 +181,7 @@ class Transformer(nn.Module):
         x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
         x = self.feat_batchnorm(x)
         x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
-        encoder_memory, memory_key_padding_mask = self.run_encoder(
-            x, supervision
-        )
+        encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision)
         x = self.ctc_output(encoder_memory)
         return x, encoder_memory, memory_key_padding_mask

@@ -266,23 +262,17 @@ class Transformer(nn.Module):
         """
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(
-            ys_in, batch_first=True, padding_value=float(eos_id)
-        )
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))

         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(
-            ys_out, batch_first=True, padding_value=float(-1)
-        )
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))

         device = memory.device
         ys_in_pad = ys_in_pad.to(device)
         ys_out_pad = ys_out_pad.to(device)

-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
-            device
-        )
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)

         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask

@@ -343,23 +333,17 @@ class Transformer(nn.Module):
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(
-            ys_in, batch_first=True, padding_value=float(eos_id)
-        )
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))

         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(
-            ys_out, batch_first=True, padding_value=float(-1)
-        )
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))

         device = memory.device
         ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
         ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)

-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
-            device
-        )
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)

         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask

@@ -632,9 +616,7 @@ def _get_activation_fn(activation: str):
     elif activation == "gelu":
         return nn.functional.gelu

-    raise RuntimeError(
-        "activation should be relu/gelu, not {}".format(activation)
-    )
+    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))


 class PositionalEncoding(nn.Module):

@@ -836,9 +818,7 @@ def encoder_padding_mask(
         1,
     ).to(torch.int32)

-    lengths = [
-        0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)
-    ]
+    lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)]
     for idx in range(supervision_segments.size(0)):
         # Note: TorchScript doesn't allow to unpack tensors as tuples
         sequence_idx = supervision_segments[idx, 0].item()

@@ -859,9 +839,7 @@ def encoder_padding_mask(
     return mask


-def decoder_padding_mask(
-    ys_pad: torch.Tensor, ignore_id: int = -1
-) -> torch.Tensor:
+def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor:
     """Generate a length mask for input.

     The masked position are filled with True,


@@ -87,9 +87,7 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
         )
         if "train" in partition:
             cut_set = (
-                cut_set
-                + cut_set.perturb_speed(0.9)
-                + cut_set.perturb_speed(1.1)
+                cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
             )
         cut_set = cut_set.compute_and_store_features(
             extractor=extractor,

@@ -116,9 +114,7 @@ def get_args():
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

     logging.basicConfig(format=formatter, level=logging.INFO)
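
This hunk (and the analogous aishell one below) triples the training data by appending 0.9x and 1.1x speed-perturbed copies before feature extraction. A standalone sketch of the idiom with lhotse; the manifest path here is illustrative, not the recipe's actual file:

    from lhotse import CutSet

    cuts = CutSet.from_file("data/manifests/cuts_train.jsonl.gz")  # hypothetical path
    # Original cuts plus 0.9x and 1.1x speed-perturbed copies:
    cuts = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1)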


@@ -83,9 +83,7 @@ def compute_fbank_aishell(num_mel_bins: int = 80):
         )
         if "train" in partition:
             cut_set = (
-                cut_set
-                + cut_set.perturb_speed(0.9)
-                + cut_set.perturb_speed(1.1)
+                cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
             )
         cut_set = cut_set.compute_and_store_features(
             extractor=extractor,

@@ -111,9 +109,7 @@ def get_args():
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

     logging.basicConfig(format=formatter, level=logging.INFO)


@@ -86,9 +86,7 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state

         word = word2id[word]
-        pieces = [
-            token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
-        ]
+        pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]

         for i in range(len(pieces) - 1):
             w = word if i == 0 else eps

@@ -142,9 +140,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
     return False


-def generate_lexicon(
-    token_sym_table: Dict[str, int], words: List[str]
-) -> Lexicon:
+def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
     """Generate a lexicon from a word list and token_sym_table.

     Args:
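
The re-wrapped pieces comprehension above is where out-of-vocabulary tokens get mapped to <unk> before the FST arcs are built. In isolation, with a toy symbol table:

    token2id = {"<unk>": 1, "h": 2, "i": 3}
    pieces = ["h", "i", "x"]  # "x" is not in the table
    ids = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
    assert ids == [2, 3, 1]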


@@ -317,9 +317,7 @@ def lexicon_to_fst(
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--lang-dir", type=str, help="The lang dir, data/lang_phone"
-    )
+    parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
     return parser.parse_args()


@@ -88,9 +88,7 @@ def test_read_lexicon(filename: str):
     fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa.draw("L.pdf", title="L")

-    fsa_disambig = lexicon_to_fst(
-        lexicon_disambig, phone2id=phone2id, word2id=word2id
-    )
+    fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
     fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
     fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa_disambig.draw("L_disambig.pdf", title="L_disambig")


@@ -76,11 +76,7 @@ from beam_search import (
 )
 from train import add_model_arguments, get_params, get_transducer_model

-from icefall.checkpoint import (
-    average_checkpoints,
-    find_checkpoints,
-    load_checkpoint,
-)
+from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,

@@ -188,8 +184,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",

@@ -249,9 +244,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)

-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)

     if params.decoding_method == "fast_beam_search":
         hyp_tokens = fast_beam_search_one_best(

@@ -263,10 +256,7 @@ def decode_one_batch(
             max_contexts=params.max_contexts,
             max_states=params.max_states,
         )
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,

@@ -387,9 +377,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"

-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results

@@ -415,9 +403,7 @@ def save_results(
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
-            results_char.append(
-                (res[0], list("".join(res[1])), list("".join(res[2])))
-            )
+            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results_char, enable_log=True

@@ -428,8 +414,7 @@ def save_results(
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)

@@ -473,9 +458,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

@@ -504,8 +487,7 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for"
-                f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
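
average_checkpoints, imported at the top of this file, combines the last few checkpoints into a single model before decoding. A minimal sketch of the idea, assuming checkpoints saved as {"model": state_dict} (the real icefall implementation handles more bookkeeping):

    import torch

    def average_models(filenames, device="cpu"):
        # Element-wise mean of the saved "model" state dicts.
        avg = None
        for f in filenames:
            state = torch.load(f, map_location=device)["model"]
            if avg is None:
                avg = {k: v.clone().float() for k, v in state.items()}
            else:
                for k in avg:
                    avg[k] += state[k].float()
        return {k: v / len(filenames) for k, v in avg.items()}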


@@ -50,11 +50,7 @@ from pathlib import Path
 import torch
 from train import add_model_arguments, get_params, get_transducer_model

-from icefall.checkpoint import (
-    average_checkpoints,
-    find_checkpoints,
-    load_checkpoint,
-)
+from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
 from icefall.lexicon import Lexicon
 from icefall.utils import str2bool

@@ -120,8 +116,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )

     add_model_arguments(parser)

@@ -157,8 +152,7 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for"
-                f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(

@@ -191,9 +185,7 @@ def main():
         model.__class__.forward = torch.jit.ignore(model.__class__.forward)
         logging.info("Using torch.jit.script")
         model = torch.jit.script(model)
-        filename = (
-            params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
-        )
+        filename = params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
         model.save(str(filename))
         logging.info(f"Saved to {filename}")
     else:

@@ -201,17 +193,14 @@ def main():
         # Save it using a format so that it can be loaded
         # by :func:`load_checkpoint`
         filename = (
-            params.exp_dir
-            / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
+            params.exp_dir / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
         )
         torch.save({"model": model.state_dict()}, str(filename))
         logging.info(f"Saved to {filename}")


 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

     logging.basicConfig(format=formatter, level=logging.INFO)
     main()


@@ -165,8 +165,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",

@@ -197,8 +196,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])

@@ -256,13 +254,9 @@ def main():
     feature_lens = [f.size(0) for f in features]
     feature_lens = torch.tensor(feature_lens, device=device)

-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))

-    encoder_out, encoder_out_lens = model.encoder(
-        x=features, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)

     num_waves = encoder_out.size(0)
     hyp_list = []

@@ -310,9 +304,7 @@ def main():
                 beam=params.beam_size,
             )
         else:
-            raise ValueError(
-                f"Unsupported decoding method: {params.method}"
-            )
+            raise ValueError(f"Unsupported decoding method: {params.method}")
         hyp_list.append(hyp)

     hyps = []

@@ -329,9 +321,7 @@ def main():
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
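
The padding value in the pad_sequence call above is not arbitrary: log(1e-10) is a very small log-energy, so padded frames look like near-silence in the log-mel domain rather than zeros (which would mean energy 1.0). A quick check:

    import math
    import torch
    from torch.nn.utils.rnn import pad_sequence

    features = [torch.randn(20, 80), torch.randn(15, 80)]  # (T_i, num_mel_bins)
    padded = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
    assert padded.shape == (2, 20, 80)
    assert padded[1, 15:].allclose(torch.tensor(math.log(1e-10)))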


@@ -49,7 +49,6 @@ import optim
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
-
 from asr_datamodule import AishellAsrDataModule
 from conformer import Conformer
 from decoder import Decoder

@@ -75,9 +74,7 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool

-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]


 def add_model_arguments(parser: argparse.ArgumentParser):

@@ -203,8 +200,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate. This value should not need "
-        "to be changed.",
+        help="The initial learning rate. This value should not need " "to be changed.",
     )

     parser.add_argument(

@@ -227,8 +223,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )

     parser.add_argument(

@@ -251,8 +246,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)" "part.",
     )

     parser.add_argument(

@@ -561,11 +555,7 @@ def compute_loss(
       warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = (
-        model.device
-        if isinstance(model, DDP)
-        else next(model.parameters()).device
-    )
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3

@@ -593,23 +583,16 @@ def compute_loss(
             # overwhelming the simple_loss and causing it to diverge,
             # in case it had not fully learned the alignment yet.
             pruned_loss_scale = (
-                0.0
-                if warmup < 1.0
-                else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-            )
-            loss = (
-                params.simple_loss_scale * simple_loss
-                + pruned_loss_scale * pruned_loss
+                0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
             )
+            loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss

     assert loss.requires_grad == is_training

     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()

     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()

@@ -725,9 +708,7 @@ def train_one_epoch(
                 scaler.update()
                 optimizer.zero_grad()
             except:  # noqa
-                display_and_save_batch(
-                    batch, params=params, graph_compiler=graph_compiler
-                )
+                display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
                 raise

         if params.print_diagnostics and batch_idx == 5:

@@ -891,7 +872,7 @@ def run(rank, world_size, args):
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)

@@ -1029,9 +1010,7 @@ def scan_pessimistic_batches_for_oom(
                     f"Failing criterion: {criterion} "
                     f"(={crit_values[criterion]}) ..."
                 )
-                display_and_save_batch(
-                    batch, params=params, graph_compiler=graph_compiler
-                )
+                display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
                 raise
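
The pruned_loss_scale schedule re-wrapped above, written out as a plain function with a few spot checks: the pruned loss is switched off before warmup reaches 1.0, damped to 0.1 until 2.0, then applied in full.

    def pruned_loss_scale(warmup: float) -> float:
        return 0.0 if warmup < 1.0 else (0.1 if 1.0 < warmup < 2.0 else 1.0)

    assert pruned_loss_scale(0.5) == 0.0
    assert pruned_loss_scale(1.5) == 0.1
    assert pruned_loss_scale(2.5) == 1.0

The unrelated-looking 2 ** 22 -> 2**22 change in the same file is also black 22.x at work: it drops spaces around ** for simple operands.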


@@ -202,8 +202,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",

@@ -263,9 +262,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)

-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)

     if params.decoding_method == "fast_beam_search":
         hyp_tokens = fast_beam_search_one_best(

@@ -277,10 +274,7 @@ def decode_one_batch(
             max_contexts=params.max_contexts,
             max_states=params.max_states,
         )
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,

@@ -401,9 +395,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"

-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results

@@ -429,9 +421,7 @@ def save_results(
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
-            results_char.append(
-                (res[0], list("".join(res[1])), list("".join(res[2])))
-            )
+            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results_char, enable_log=True

@@ -442,8 +432,7 @@ def save_results(
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tCER", file=f)

@@ -488,9 +477,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

@@ -518,9 +505,9 @@ def main():
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"

@@ -551,9 +538,9 @@ def main():
             )
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"


@@ -132,8 +132,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )

     add_model_arguments(parser)

@@ -166,9 +165,9 @@ def main():
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"

@@ -195,9 +194,9 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"

@@ -252,9 +251,7 @@ def main():
         model.__class__.forward = torch.jit.ignore(model.__class__.forward)
         logging.info("Using torch.jit.script")
         model = torch.jit.script(model)
-        filename = (
-            params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
-        )
+        filename = params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
         model.save(str(filename))
         logging.info(f"Saved to {filename}")
     else:

@@ -262,17 +259,14 @@ def main():
         # Save it using a format so that it can be loaded
         # by :func:`load_checkpoint`
         filename = (
-            params.exp_dir
-            / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
+            params.exp_dir / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
         )
         torch.save({"model": model.state_dict()}, str(filename))
         logging.info(f"Saved to {filename}")


 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
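
The two branches above produce the two artifact formats this export script supports. Sketched with a stand-in module (a plain torch.nn.Linear instead of the recipe's transducer):

    import torch

    model = torch.nn.Linear(4, 4).eval()

    # TorchScript export: loadable from C++ or Python without the model code.
    torch.jit.script(model).save("cpu_jit.pt")

    # Plain state-dict export, in the layout load_checkpoint expects.
    torch.save({"model": model.state_dict()}, "pretrained.pt")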


@@ -84,9 +84,7 @@ class Transducer(nn.Module):
         self.decoder_datatang = decoder_datatang
         self.joiner_datatang = joiner_datatang

-        self.simple_am_proj = ScaledLinear(
-            encoder_dim, vocab_size, initial_speed=0.5
-        )
+        self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5)
         self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)

         if decoder_datatang is not None:

@@ -179,9 +177,7 @@ class Transducer(nn.Module):
         y_padded = y.pad(mode="constant", padding_value=0)

         y_padded = y_padded.to(torch.int64)
-        boundary = torch.zeros(
-            (x.size(0), 4), dtype=torch.int64, device=x.device
-        )
+        boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
         boundary[:, 2] = y_lens
         boundary[:, 3] = encoder_out_lens
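The boundary tensor built above follows k2's pruned RNN-T convention of one [begin_symbol, begin_frame, end_symbol, end_frame] row per utterance; a toy sketch with assumed lengths:

    import torch

    y_lens = torch.tensor([5, 7])  # label lengths for two hypothetical utterances
    encoder_out_lens = torch.tensor([20, 18])  # frame counts after subsampling

    boundary = torch.zeros((2, 4), dtype=torch.int64)
    boundary[:, 2] = y_lens  # end symbol index per utterance
    boundary[:, 3] = encoder_out_lens  # end frame index per utterance
    print(boundary)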

View File

@@ -165,8 +165,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",

@@ -197,8 +196,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])

@@ -257,13 +255,9 @@ def main():
     feature_lens = [f.size(0) for f in features]
     feature_lens = torch.tensor(feature_lens, device=device)

-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))

-    encoder_out, encoder_out_lens = model.encoder(
-        x=features, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)

     num_waves = encoder_out.size(0)
     hyp_list = []

@@ -311,9 +305,7 @@ def main():
                 beam=params.beam_size,
             )
         else:
-            raise ValueError(
-                f"Unsupported decoding method: {params.method}"
-            )
+            raise ValueError(f"Unsupported decoding method: {params.method}")
         hyp_list.append(hyp)

     hyps = []

@@ -330,9 +322,7 @@ def main():

 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
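The one-line pad_sequence call above pads with log(1e-10) so padded frames resemble silence in log-mel space; a runnable sketch with made-up shapes:

    import math

    import torch
    from torch.nn.utils.rnn import pad_sequence

    features = [torch.randn(10, 80), torch.randn(7, 80)]  # two fbank matrices
    padded = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
    print(padded.shape)  # torch.Size([2, 10, 80])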

View File

@@ -96,9 +96,7 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool

-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]


 def add_model_arguments(parser: argparse.ArgumentParser):

@@ -224,8 +222,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate. This value should not need "
-        "to be changed.",
+        help="The initial learning rate. This value should not need " "to be changed.",
     )

     parser.add_argument(

@@ -248,8 +245,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )

     parser.add_argument(

@@ -272,8 +268,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)" "part.",
     )

     parser.add_argument(

@@ -635,11 +630,7 @@ def compute_loss(
       warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = (
-        model.device
-        if isinstance(model, DDP)
-        else next(model.parameters()).device
-    )
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3

@@ -670,23 +661,16 @@ def compute_loss(
             # overwhelming the simple_loss and causing it to diverge,
             # in case it had not fully learned the alignment yet.
             pruned_loss_scale = (
-                0.0
-                if warmup < 1.0
-                else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+                0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
             )
-            loss = (
-                params.simple_loss_scale * simple_loss
-                + pruned_loss_scale * pruned_loss
-            )
+            loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss

     assert loss.requires_grad == is_training

     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()

     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -824,9 +808,7 @@ def train_one_epoch(
                 )

         # summary stats
         if datatang_train_dl is not None:
-            tot_loss = (
-                tot_loss * (1 - 1 / params.reset_interval)
-            ) + loss_info
+            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info

         if aishell:
             aishell_tot_loss = (

@@ -847,9 +829,7 @@ def train_one_epoch(
             scaler.update()
             optimizer.zero_grad()
         except:  # noqa
-            display_and_save_batch(
-                batch, params=params, graph_compiler=graph_compiler
-            )
+            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
             raise

         if params.print_diagnostics and batch_idx == 5:

@@ -892,9 +872,7 @@ def train_one_epoch(
             cur_lr = scheduler.get_last_lr()[0]
             if datatang_train_dl is not None:
                 datatang_str = f"datatang_tot_loss[{datatang_tot_loss}], "
-                tot_loss_str = (
-                    f"tot_loss[{tot_loss}], batch size: {batch_size}, "
-                )
+                tot_loss_str = f"tot_loss[{tot_loss}], batch size: {batch_size}, "
             else:
                 tot_loss_str = ""
                 datatang_str = ""

@@ -1067,7 +1045,7 @@ def run(rank, world_size, args):
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
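The `2 ** 22` to `2**22` change reflects black 22.x hugging the power operator when both operands are simple; a minimal illustration:

    max_size = 2**22  # simple operands: no spaces around **
    print(max_size)  # 4194304, i.e. the 4 megabytes per sub-module noted above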
@@ -1076,9 +1054,7 @@ def run(rank, world_size, args):
     train_cuts = filter_short_and_long_utterances(train_cuts)

     if args.enable_musan:
-        cuts_musan = load_manifest(
-            Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz")
     else:
         cuts_musan = None

@@ -1093,9 +1069,7 @@ def run(rank, world_size, args):
     if params.datatang_prob > 0:
         datatang = AIDatatang200zh(manifest_dir=args.manifest_dir)
         train_datatang_cuts = datatang.train_cuts()
-        train_datatang_cuts = filter_short_and_long_utterances(
-            train_datatang_cuts
-        )
+        train_datatang_cuts = filter_short_and_long_utterances(train_datatang_cuts)
         train_datatang_cuts = train_datatang_cuts.repeat(times=None)
         datatang_train_dl = asr_datamodule.train_dataloaders(
             train_datatang_cuts,

@@ -1249,9 +1223,7 @@ def scan_pessimistic_batches_for_oom(
                     f"Failing criterion: {criterion} "
                     f"(={crit_values[criterion]}) ..."
                 )
-            display_and_save_batch(
-                batch, params=params, graph_compiler=graph_compiler
-            )
+            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
             raise

View File

@@ -183,17 +183,13 @@ class AishellAsrDataModule:
     def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
-            self.args.manifest_dir / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")

         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")

@@ -215,9 +211,7 @@ class AishellAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.

@@ -260,9 +254,7 @@ class AishellAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )

@@ -308,9 +300,7 @@ class AishellAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:

@@ -366,13 +356,9 @@ class AishellAsrDataModule:
     @lru_cache()
     def valid_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "aishell_cuts_dev.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_dev.jsonl.gz")

     @lru_cache()
     def test_cuts(self) -> List[CutSet]:
         logging.info("About to get test cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "aishell_cuts_test.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_test.jsonl.gz")
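Whether black collapses such calls onto one line is governed by the 88-column limit and the magic trailing comma; a small illustration (function and arguments are hypothetical):

    def f(a, b, c):
        return a + b + c

    r1 = f(1, 2, 3)  # fits in 88 columns, no trailing comma: stays on one line
    r2 = f(
        1,
        2,
        3,
    )  # a trailing comma after the last argument keeps the call exploded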

View File

@@ -265,9 +265,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"

-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results

@@ -289,9 +287,7 @@ def save_results(
         # We compute CER for aishell dataset.
         results_char = []
         for res in results:
-            results_char.append(
-                (res[0], list("".join(res[1])), list("".join(res[2])))
-            )
+            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
         with open(errs_filename, "w") as f:
             wer = write_error_stats(f, f"{test_set_name}-{key}", results_char)
             test_set_wers[key] = wer

@@ -335,9 +331,7 @@ def main():
     logging.info(f"device: {device}")

-    HLG = k2.Fsa.from_dict(
-        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
-    )
+    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
     HLG = HLG.to(device)
     assert HLG.requires_grad is False

@@ -362,9 +356,7 @@ def main():
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
+        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")

     model.to(device)
     model.eval()

@@ -392,9 +384,7 @@ def main():
             lexicon=lexicon,
         )

-        save_results(
-            params=params, test_set_name=test_set, results_dict=results_dict
-        )
+        save_results(params=params, test_set_name=test_set, results_dict=results_dict)

     logging.info("Done!")

View File

@@ -66,10 +66,7 @@ class TdnnLstm(nn.Module):
             nn.BatchNorm1d(num_features=500, affine=False),
         )
         self.lstms = nn.ModuleList(
-            [
-                nn.LSTM(input_size=500, hidden_size=500, num_layers=1)
-                for _ in range(5)
-            ]
+            [nn.LSTM(input_size=500, hidden_size=500, num_layers=1) for _ in range(5)]
         )
         self.lstm_bnorms = nn.ModuleList(
             [nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)]

View File

@@ -53,9 +53,7 @@ def get_parser():
         help="Path to words.txt",
     )

-    parser.add_argument(
-        "--HLG", type=str, required=True, help="Path to HLG.pt."
-    )
+    parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")

     parser.add_argument(
         "--method",

@@ -113,8 +111,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])

@@ -173,9 +170,7 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)

-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
     features = features.permute(0, 2, 1)  # now features is [N, C, T]

     with torch.no_grad():

@@ -219,9 +214,7 @@ def main():

 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

     logging.basicConfig(format=formatter, level=logging.INFO)
     main()

View File

@@ -49,12 +49,7 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.graph_compiler import CtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
-from icefall.utils import (
-    AttributeDict,
-    encode_supervisions,
-    setup_logger,
-    str2bool,
-)
+from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool


 def get_parser():

View File

@@ -47,9 +47,9 @@ def greedy_search(
     device = model.device

-    decoder_input = torch.tensor(
-        [blank_id] * context_size, device=device
-    ).reshape(1, context_size)
+    decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
+        1, context_size
+    )

     decoder_out = model.decoder(decoder_input, need_pad=False)

@@ -81,9 +81,9 @@ def greedy_search(
         y = logits.argmax().item()
         if y != blank_id:
             hyp.append(y)
-            decoder_input = torch.tensor(
-                [hyp[-context_size:]], device=device
-            ).reshape(1, context_size)
+            decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape(
+                1, context_size
+            )

             decoder_out = model.decoder(decoder_input, need_pad=False)

@@ -157,9 +157,7 @@ class HypothesisList(object):
         """
         if length_norm:
-            return max(
-                self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)
-            )
+            return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
         else:
             return max(self._data.values(), key=lambda hyp: hyp.log_prob)

@@ -246,9 +244,9 @@ def beam_search(
     device = model.device

-    decoder_input = torch.tensor(
-        [blank_id] * context_size, device=device
-    ).reshape(1, context_size)
+    decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
+        1, context_size
+    )

     decoder_out = model.decoder(decoder_input, need_pad=False)
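The reshaped decoder input above is the (1, context_size) blank context that seeds the stateless decoder; a toy sketch with assumed ids:

    import torch

    blank_id, context_size = 0, 2  # assumed values for illustration
    device = torch.device("cpu")
    decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
        1, context_size
    )
    print(decoder_input)  # tensor([[0, 0]])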

View File

@@ -155,9 +155,7 @@ class ConformerEncoderLayer(nn.Module):
         normalize_before: bool = True,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
-        self.self_attn = RelPositionMultiheadAttention(
-            d_model, nhead, dropout=0.0
-        )
+        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)

         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),

@@ -175,18 +173,14 @@ class ConformerEncoderLayer(nn.Module):

         self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)

-        self.norm_ff_macaron = nn.LayerNorm(
-            d_model
-        )  # for the macaron style FNN module
+        self.norm_ff_macaron = nn.LayerNorm(d_model)  # for the macaron style FNN module
         self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
         self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module

         self.ff_scale = 0.5

         self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
-        self.norm_final = nn.LayerNorm(
-            d_model
-        )  # for the final output of the block
+        self.norm_final = nn.LayerNorm(d_model)  # for the final output of the block

         self.dropout = nn.Dropout(dropout)

@@ -220,9 +214,7 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(
-            self.feed_forward_macaron(src)
-        )
+        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)

@@ -341,9 +333,7 @@ class RelPositionalEncoding(torch.nn.Module):
     """

-    def __init__(
-        self, d_model: int, dropout_rate: float, max_len: int = 5000
-    ) -> None:
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model

@@ -359,9 +349,7 @@ class RelPositionalEncoding(torch.nn.Module):
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(1) >= x.size(1) * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
-                    x.device
-                ):
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means to the position of query vector and `j` means the

@@ -631,9 +619,9 @@ class RelPositionMultiheadAttention(nn.Module):

         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(
-                query, in_proj_weight, in_proj_bias
-            ).chunk(3, dim=-1)
+            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
+                3, dim=-1
+            )

         elif torch.equal(key, value):
             # encoder-decoder attention

@@ -701,31 +689,22 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
                 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                    raise RuntimeError(
-                        "The size of the 2D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
                     query.size(0),
                     key.size(0),
                 ]:
-                    raise RuntimeError(
-                        "The size of the 3D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
             else:
                 raise RuntimeError(
-                    "attn_mask's dimension {} is not supported".format(
-                        attn_mask.dim()
-                    )
+                    "attn_mask's dimension {} is not supported".format(attn_mask.dim())
                 )
             # attn_mask's dim is 3 now.

         # convert ByteTensor key_padding_mask to bool
-        if (
-            key_padding_mask is not None
-            and key_padding_mask.dtype == torch.uint8
-        ):
+        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
             warnings.warn(
                 "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
             )

@@ -764,9 +743,7 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(
-            q_with_bias_u, k
-        )  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)

         # compute matrix b and matrix d
         matrix_bd = torch.matmul(

@@ -778,9 +755,7 @@ class RelPositionMultiheadAttention(nn.Module):
             matrix_ac + matrix_bd
         ) * scaling  # (batch, head, time1, time2)

-        attn_output_weights = attn_output_weights.view(
-            bsz * num_heads, tgt_len, -1
-        )
+        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)

         assert list(attn_output_weights.size()) == [
             bsz * num_heads,

@@ -814,13 +789,9 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
         attn_output = (
-            attn_output.transpose(0, 1)
-            .contiguous()
-            .view(tgt_len, bsz, embed_dim)
-        )
-        attn_output = nn.functional.linear(
-            attn_output, out_proj_weight, out_proj_bias
+            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
         )
+        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)

         if need_weights:
             # average attention weights over heads

@@ -843,9 +814,7 @@ class ConvolutionModule(nn.Module):
     """

-    def __init__(
-        self, channels: int, kernel_size: int, bias: bool = True
-    ) -> None:
+    def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None:
         """Construct an ConvolutionModule object."""
         super(ConvolutionModule, self).__init__()
         # kernerl_size should be a odd number for 'SAME' padding
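The matrix_ac product above is the content term of Transformer-XL relative attention; a shape-only sketch with toy dimensions (layouts simplified relative to the actual code path):

    import torch

    batch, head, d_k, time1, time2 = 2, 4, 16, 10, 10  # toy sizes
    q_with_bias_u = torch.randn(batch, head, time1, d_k)
    k = torch.randn(batch, head, d_k, time2)  # already permuted, as in the hunk
    matrix_ac = torch.matmul(q_with_bias_u, k)
    print(matrix_ac.shape)  # torch.Size([2, 4, 10, 10])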

View File

@@ -99,8 +99,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",

@@ -227,9 +226,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)

-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
     batch_size = encoder_out.size(0)

@@ -248,9 +245,7 @@ def decode_one_batch(
                 model=model, encoder_out=encoder_out_i, beam=params.beam_size
             )
         else:
-            raise ValueError(
-                f"Unsupported decoding method: {params.decoding_method}"
-            )
+            raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
         hyps.append([lexicon.token_table[i] for i in hyp])

     if params.decoding_method == "greedy_search":

@@ -319,9 +314,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"

-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results

@@ -346,9 +339,7 @@ def save_results(
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
-            results_char.append(
-                (res[0], list("".join(res[1])), list("".join(res[2])))
-            )
+            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results_char, enable_log=True

@@ -359,8 +350,7 @@ def save_results(
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tCER", file=f)

@@ -430,9 +420,7 @@ def main():
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
+        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
         return

     model.to(device)

View File

@@ -86,9 +86,7 @@ class Decoder(nn.Module):
         if self.context_size > 1:
             embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
-                embedding_out = F.pad(
-                    embedding_out, pad=(self.context_size - 1, 0)
-                )
+                embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
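F.pad with pad=(context_size - 1, 0) left-pads only the last (time) axis, so the causal convolution over previous tokens sees a full left context during training; a runnable sketch with assumed sizes:

    import torch
    import torch.nn.functional as F

    embedding_out = torch.randn(1, 8, 5)  # assumed (N, embed_dim, U) after permute
    context_size = 2
    padded = F.pad(embedding_out, pad=(context_size - 1, 0))
    print(padded.shape)  # torch.Size([1, 8, 6]): one frame of left padding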

View File

@@ -110,8 +110,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )

     return parser

@@ -243,9 +242,7 @@ def main():

 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

     logging.basicConfig(format=formatter, level=logging.INFO)
     main()

View File

@@ -103,9 +103,7 @@ class Transducer(nn.Module):
         y_padded = y.pad(mode="constant", padding_value=0)

         y_padded = y_padded.to(torch.int64)
-        boundary = torch.zeros(
-            (x.size(0), 4), dtype=torch.int64, device=x.device
-        )
+        boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
         boundary[:, 2] = y_lens
         boundary[:, 3] = x_lens

View File

@@ -117,8 +117,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",

@@ -212,8 +211,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])

@@ -273,9 +271,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]

-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))

     feature_lengths = torch.tensor(feature_lengths, device=device)

@@ -319,9 +315,7 @@ def main():

 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

     logging.basicConfig(format=formatter, level=logging.INFO)
     main()

View File

@@ -126,8 +126,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )

     parser.add_argument(

@@ -389,9 +388,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()

     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()

@@ -504,9 +501,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)

         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")

@@ -625,9 +620,7 @@ def run(rank, world_size, args):
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

         if rank == 0:

View File

@@ -250,9 +250,7 @@ def _get_activation_fn(activation: str):
     elif activation == "gelu":
         return nn.functional.gelu

-    raise RuntimeError(
-        "activation should be relu/gelu, not {}".format(activation)
-    )
+    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))


 class PositionalEncoding(nn.Module):
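For context, the helper above maps an activation name to a function and raises otherwise; a minimal restatement (the relu branch is assumed from the surrounding file, not shown in the hunk):

    import torch.nn as nn

    def _get_activation_fn(activation: str):
        if activation == "relu":  # assumed branch
            return nn.functional.relu
        elif activation == "gelu":
            return nn.functional.gelu
        raise RuntimeError("activation should be relu/gelu, not {}".format(activation))

    print(_get_activation_fn("gelu"))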

View File

@@ -29,10 +29,7 @@ from lhotse.dataset import (
     K2SpeechRecognitionDataset,
     SpecAugment,
 )
-from lhotse.dataset.input_strategies import (
-    OnTheFlyFeatures,
-    PrecomputedFeatures,
-)
+from lhotse.dataset.input_strategies import OnTheFlyFeatures, PrecomputedFeatures
 from torch.utils.data import DataLoader

 from icefall.utils import str2bool

@@ -162,9 +159,7 @@ class AsrDataModule:
         if cuts_musan is not None:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")

@@ -173,9 +168,7 @@ class AsrDataModule:
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.

@@ -252,9 +245,7 @@ class AsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:

View File

@@ -170,8 +170,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )

     parser.add_argument(

@@ -227,9 +226,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)

-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)

     if params.decoding_method == "fast_beam_search":
         hyp_tokens = fast_beam_search_one_best(

@@ -241,10 +238,7 @@ def decode_one_batch(
             max_contexts=params.max_contexts,
             max_states=params.max_states,
         )
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,

@@ -365,9 +359,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"

-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results

@@ -393,9 +385,7 @@ def save_results(
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
-            results_char.append(
-                (res[0], list("".join(res[1])), list("".join(res[2])))
-            )
+            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results_char, enable_log=True

@@ -406,8 +396,7 @@ def save_results(
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tCER", file=f)

@@ -448,9 +437,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

View File

@@ -109,8 +109,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )

     return parser

@@ -241,9 +240,7 @@ def main():

 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

     logging.basicConfig(format=formatter, level=logging.INFO)
     main()

View File

@@ -165,8 +165,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",

@@ -195,8 +194,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])

@@ -254,13 +252,9 @@ def main():
     feature_lens = [f.size(0) for f in features]
     feature_lens = torch.tensor(feature_lens, device=device)

-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))

-    encoder_out, encoder_out_lens = model.encoder(
-        x=features, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)

     num_waves = encoder_out.size(0)
     hyp_list = []

@@ -308,9 +302,7 @@ def main():
                 beam=params.beam_size,
             )
         else:
-            raise ValueError(
-                f"Unsupported decoding method: {params.method}"
-            )
+            raise ValueError(f"Unsupported decoding method: {params.method}")
         hyp_list.append(hyp)

     hyps = []

@@ -327,9 +319,7 @@ def main():

 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

     logging.basicConfig(format=formatter, level=logging.INFO)
     main()

View File

@@ -149,8 +149,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )

     parser.add_argument(

@@ -168,8 +167,7 @@ def get_parser():
         "--datatang-prob",
         type=float,
         default=0.2,
-        help="The probability to select a batch from the "
-        "aidatatang_200zh dataset",
+        help="The probability to select a batch from the " "aidatatang_200zh dataset",
     )

     return parser

@@ -449,9 +447,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()

     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()

@@ -605,9 +601,7 @@ def train_one_epoch(
                     f"train/current_{prefix}_",
                     params.batch_idx_train,
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                 aishell_tot_loss.write_summary(
                     tb_writer, "train/aishell_tot_", params.batch_idx_train
                 )

@@ -735,9 +729,7 @@ def run(rank, world_size, args):
     train_datatang_cuts = train_datatang_cuts.repeat(times=None)

     if args.enable_musan:
-        cuts_musan = load_manifest(
-            Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz")
     else:
         cuts_musan = None

@@ -776,9 +768,7 @@ def run(rank, world_size, args):
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

         if rank == 0:
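The --datatang-prob option above controls how often a batch is drawn from aidatatang_200zh instead of aishell; a hypothetical sketch of that selection rule (next_batch and the iterators are illustrative, not the script's actual code):

    import random

    datatang_prob = 0.2  # default from the hunk above

    def next_batch(aishell_iter, datatang_iter):
        # With probability datatang_prob, take the next datatang batch.
        if random.random() < datatang_prob:
            return next(datatang_iter)
        return next(aishell_iter)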

View File

@@ -171,8 +171,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )

     parser.add_argument(

@@ -231,9 +230,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)

-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)

     if params.decoding_method == "fast_beam_search":
         hyp_tokens = fast_beam_search_one_best(

@@ -245,10 +242,7 @@ def decode_one_batch(
             max_contexts=params.max_contexts,
             max_states=params.max_states,
         )
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,

@@ -369,9 +363,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"

-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results

@@ -397,9 +389,7 @@ def save_results(
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
-            results_char.append(
-                (res[0], list("".join(res[1])), list("".join(res[2])))
-            )
+            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results_char, enable_log=True

@@ -410,8 +400,7 @@ def save_results(
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tCER", file=f)

@@ -452,9 +441,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

View File

@@ -109,8 +109,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )

     return parser

@@ -241,9 +240,7 @@ def main():

 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

     logging.basicConfig(format=formatter, level=logging.INFO)
     main()

View File

@@ -165,8 +165,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",

@@ -195,8 +194,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
        )
         # We use only the first channel
         ans.append(wave[0])

@@ -254,13 +252,9 @@ def main():
     feature_lens = [f.size(0) for f in features]
     feature_lens = torch.tensor(feature_lens, device=device)

-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))

-    encoder_out, encoder_out_lens = model.encoder(
-        x=features, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)

     num_waves = encoder_out.size(0)
     hyp_list = []

@@ -308,9 +302,7 @@ def main():
                 beam=params.beam_size,
             )
         else:
-            raise ValueError(
-                f"Unsupported decoding method: {params.method}"
-            )
+            raise ValueError(f"Unsupported decoding method: {params.method}")
         hyp_list.append(hyp)

     hyps = []

@@ -327,9 +319,7 @@ def main():

 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

     logging.basicConfig(format=formatter, level=logging.INFO)
     main()

View File

@@ -142,8 +142,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )

     parser.add_argument(

@@ -414,9 +413,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()

     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()

@@ -529,9 +526,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)

         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")

@@ -657,9 +652,7 @@ def run(rank, world_size, args):
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

         if rank == 0:

egs/aishell2/ASR/local/__init__.py Executable file → Normal file
View File

View File

@@ -83,9 +83,7 @@ def compute_fbank_aishell2(num_mel_bins: int = 80):
         )
         if "train" in partition:
             cut_set = (
-                cut_set
-                + cut_set.perturb_speed(0.9)
-                + cut_set.perturb_speed(1.1)
+                cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
             )
         cut_set = cut_set.compute_and_store_features(
             extractor=extractor,

@@ -111,9 +109,7 @@ def get_args():

 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

     logging.basicConfig(format=formatter, level=logging.INFO)
View File

View File

@@ -216,13 +216,9 @@ class AiShell2AsrDataModule:
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             logging.info("About to get Musan cuts")
-            cuts_musan = load_manifest(
-                self.args.manifest_dir / "musan_cuts.jsonl.gz"
-            )
+            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")

@@ -244,9 +240,7 @@ class AiShell2AsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.

@@ -290,9 +284,7 @@ class AiShell2AsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )

@@ -348,9 +340,7 @@ class AiShell2AsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:

@@ -406,9 +396,7 @@ class AiShell2AsrDataModule:
     @lru_cache()
     def valid_cuts(self) -> CutSet:
         logging.info("About to gen cuts from aishell2_cuts_dev.jsonl.gz")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "aishell2_cuts_dev.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "aishell2_cuts_dev.jsonl.gz")

     @lru_cache()
     def test_cuts(self) -> CutSet:
@@ -269,8 +269,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -348,9 +347,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -409,10 +406,7 @@ def decode_one_batch(
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -538,9 +532,7 @@ def decode_dataset(
 
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
@@ -573,8 +565,7 @@ def save_results(
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -625,9 +616,7 @@ def main():
         if "LG" in params.decoding_method:
             params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -661,9 +650,9 @@ def main():
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -690,9 +679,9 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -749,9 +738,7 @@ def main():
             )
             decoding_graph.scores *= params.ngram_lm_scale
         else:
-            decoding_graph = k2.trivial_graph(
-                params.vocab_size - 1, device=device
-            )
+            decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
     else:
         decoding_graph = None

@@ -133,8 +133,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     add_model_arguments(parser)
@@ -167,9 +166,9 @@ def main():
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -196,9 +195,9 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -266,9 +265,7 @@ def main():
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()

@@ -159,8 +159,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -192,8 +191,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -254,15 +252,11 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=features, x_lens=feature_lengths
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
 
     num_waves = encoder_out.size(0)
     hyps = []
@@ -334,9 +328,7 @@ def main():
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()

@@ -92,9 +92,7 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 
 def add_model_arguments(parser: argparse.ArgumentParser):
@@ -220,8 +218,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate. This value should not need "
-        "to be changed.",
+        help="The initial learning rate. This value should not need " "to be changed.",
     )
 
     parser.add_argument(
@@ -244,8 +241,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -268,8 +264,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)" "part.",
     )
 
     parser.add_argument(
@@ -603,11 +598,7 @@ def compute_loss(
       warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = (
-        model.device
-        if isinstance(model, DDP)
-        else next(model.parameters()).device
-    )
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -636,23 +627,16 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-        )
-        loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
         )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -771,9 +755,7 @@ def train_one_epoch(
             scaler.update()
             optimizer.zero_grad()
         except:  # noqa
-            display_and_save_batch(
-                batch, params=params, graph_compiler=graph_compiler
-            )
+            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
             raise
 
         if params.print_diagnostics and batch_idx == 5:
@@ -829,9 +811,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -939,7 +919,7 @@ def run(rank, world_size, args):
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
@@ -1104,9 +1084,7 @@ def scan_pessimistic_batches_for_oom(
                     f"Failing criterion: {criterion} "
                     f"(={crit_values[criterion]}) ..."
                 )
-            display_and_save_batch(
-                batch, params=params, graph_compiler=graph_compiler
-            )
+            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
             raise

@@ -85,9 +85,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80):
         )
         if "train" in partition:
             cut_set = (
-                cut_set
-                + cut_set.perturb_speed(0.9)
-                + cut_set.perturb_speed(1.1)
+                cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
             )
         cut_set = cut_set.compute_and_store_features(
             extractor=extractor,
@@ -120,9 +118,7 @@ def get_args():
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)

@@ -86,9 +86,7 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state
 
         word = word2id[word]
-        pieces = [
-            token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
-        ]
+        pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
 
         for i in range(len(pieces) - 1):
             w = word if i == 0 else eps
@@ -142,9 +140,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
     return False
 
 
-def generate_lexicon(
-    token_sym_table: Dict[str, int], words: List[str]
-) -> Lexicon:
+def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
     """Generate a lexicon from a word list and token_sym_table.
 
     Args:

@@ -317,9 +317,7 @@ def lexicon_to_fst(
 
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--lang-dir", type=str, help="The lang dir, data/lang_phone"
-    )
+    parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
     return parser.parse_args()

@@ -88,9 +88,7 @@ def test_read_lexicon(filename: str):
     fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa.draw("L.pdf", title="L")
 
-    fsa_disambig = lexicon_to_fst(
-        lexicon_disambig, phone2id=phone2id, word2id=word2id
-    )
+    fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
     fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
     fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa_disambig.draw("L_disambig.pdf", title="L_disambig")

@@ -56,9 +56,7 @@ def get_parser():
     parser.add_argument(
         "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
     )
-    parser.add_argument(
-        "--space", default="<space>", type=str, help="space symbol"
-    )
+    parser.add_argument("--space", default="<space>", type=str, help="space symbol")
     parser.add_argument(
         "--non-lang-syms",
         "-l",
@@ -66,9 +64,7 @@ def get_parser():
         type=str,
         help="list of non-linguistic symobles, e.g., <NOISE> etc.",
     )
-    parser.add_argument(
-        "text", type=str, default=False, nargs="?", help="input text"
-    )
+    parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
     parser.add_argument(
         "--trans_type",
         "-t",
@@ -108,8 +104,7 @@ def token2id(
             if token_type == "lazy_pinyin":
                 text = lazy_pinyin(chars_list)
                 sub_ids = [
-                    token_table[txt] if txt in token_table else oov_id
-                    for txt in text
+                    token_table[txt] if txt in token_table else oov_id for txt in text
                 ]
                 ids.append(sub_ids)
             else:  # token_type = "pinyin"
@@ -135,9 +130,7 @@ def main():
     if args.text:
         f = codecs.open(args.text, encoding="utf-8")
     else:
-        f = codecs.getreader("utf-8")(
-            sys.stdin if is_python2 else sys.stdin.buffer
-        )
+        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
 
     sys.stdout = codecs.getwriter("utf-8")(
         sys.stdout if is_python2 else sys.stdout.buffer

@@ -222,17 +222,13 @@ class Aishell4AsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
-            self.args.manifest_dir / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
@@ -254,9 +250,7 @@ class Aishell4AsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -300,9 +294,7 @@ class Aishell4AsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -359,9 +351,7 @@ class Aishell4AsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:

@@ -201,8 +201,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -260,9 +259,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -277,10 +274,7 @@ def decode_one_batch(
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -401,9 +395,7 @@ def decode_dataset(
 
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
@@ -436,8 +428,7 @@ def save_results(
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -480,9 +471,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -510,9 +499,9 @@ def main():
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -543,9 +532,9 @@ def main():
             )
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"

@@ -136,8 +136,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     add_model_arguments(parser)
@@ -169,9 +168,9 @@ def main():
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -202,9 +201,9 @@ def main():
             )
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -276,9 +275,7 @@ def main():
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()

@@ -172,8 +172,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -205,8 +204,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
        )
         # We use only the first channel
         ans.append(wave[0])
@@ -266,15 +264,11 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=features, x_lens=feature_lengths
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
 
     num_waves = encoder_out.size(0)
     hyps = []
@@ -306,10 +300,7 @@ def main():
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -350,9 +341,7 @@ def main():
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()

@@ -85,9 +85,7 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 
 def add_model_arguments(parser: argparse.ArgumentParser):
@@ -213,8 +211,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate. This value should not need "
-        "to be changed.",
+        help="The initial learning rate. This value should not need " "to be changed.",
     )
 
     parser.add_argument(
@@ -237,8 +234,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -261,8 +257,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)" "part.",
     )
 
     parser.add_argument(
@@ -599,11 +594,7 @@ def compute_loss(
       warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = (
-        model.device
-        if isinstance(model, DDP)
-        else next(model.parameters()).device
-    )
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -633,22 +624,15 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-        )
-        loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
         )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -827,9 +811,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -937,7 +919,7 @@ def run(rank, world_size, args):
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)

@@ -84,9 +84,7 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80):
         )
         if "train" in partition:
             cut_set = (
-                cut_set
-                + cut_set.perturb_speed(0.9)
-                + cut_set.perturb_speed(1.1)
+                cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
             )
         cur_num_jobs = num_jobs if ex is None else 80
         cur_num_jobs = min(cur_num_jobs, len(cut_set))
@@ -121,9 +119,7 @@ def get_args():
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)

@@ -86,9 +86,7 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state
 
         word = word2id[word]
-        pieces = [
-            token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
-        ]
+        pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
 
         for i in range(len(pieces) - 1):
             w = word if i == 0 else eps
@@ -142,9 +140,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
     return False
 
 
-def generate_lexicon(
-    token_sym_table: Dict[str, int], words: List[str]
-) -> Lexicon:
+def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
     """Generate a lexicon from a word list and token_sym_table.
 
     Args:

@@ -317,9 +317,7 @@ def lexicon_to_fst(
 
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--lang-dir", type=str, help="The lang dir, data/lang_phone"
-    )
+    parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
     return parser.parse_args()

@@ -88,9 +88,7 @@ def test_read_lexicon(filename: str):
     fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa.draw("L.pdf", title="L")
 
-    fsa_disambig = lexicon_to_fst(
-        lexicon_disambig, phone2id=phone2id, word2id=word2id
-    )
+    fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
     fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
     fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa_disambig.draw("L_disambig.pdf", title="L_disambig")

@@ -30,8 +30,8 @@ with word segmenting:
 import argparse
 
-import paddle
 import jieba
+import paddle
 from tqdm import tqdm
 
 paddle.enable_static()

@@ -56,9 +56,7 @@ def get_parser():
     parser.add_argument(
         "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
     )
-    parser.add_argument(
-        "--space", default="<space>", type=str, help="space symbol"
-    )
+    parser.add_argument("--space", default="<space>", type=str, help="space symbol")
     parser.add_argument(
         "--non-lang-syms",
         "-l",
@@ -66,9 +64,7 @@ def get_parser():
         type=str,
         help="list of non-linguistic symobles, e.g., <NOISE> etc.",
     )
-    parser.add_argument(
-        "text", type=str, default=False, nargs="?", help="input text"
-    )
+    parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
     parser.add_argument(
         "--trans_type",
         "-t",
@@ -108,8 +104,7 @@ def token2id(
             if token_type == "lazy_pinyin":
                 text = lazy_pinyin(chars_list)
                 sub_ids = [
-                    token_table[txt] if txt in token_table else oov_id
-                    for txt in text
+                    token_table[txt] if txt in token_table else oov_id for txt in text
                 ]
                 ids.append(sub_ids)
             else:  # token_type = "pinyin"
@@ -135,9 +130,7 @@ def main():
     if args.text:
         f = codecs.open(args.text, encoding="utf-8")
     else:
-        f = codecs.getreader("utf-8")(
-            sys.stdin if is_python2 else sys.stdin.buffer
-        )
+        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
 
     sys.stdout = codecs.getwriter("utf-8")(
         sys.stdout if is_python2 else sys.stdout.buffer

@@ -205,17 +205,13 @@ class AlimeetingAsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
-            self.args.manifest_dir / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
@@ -237,9 +233,7 @@ class AlimeetingAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -282,9 +276,7 @@ class AlimeetingAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -341,9 +333,7 @@ class AlimeetingAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:

@@ -70,11 +70,7 @@ from beam_search import (
 from lhotse.cut import Cut
 from train import get_params, get_transducer_model
 
-from icefall.checkpoint import (
-    average_checkpoints,
-    find_checkpoints,
-    load_checkpoint,
-)
+from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -193,8 +189,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -249,9 +244,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -266,10 +259,7 @@ def decode_one_batch(
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -390,9 +380,7 @@ def decode_dataset(
 
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
@@ -425,8 +413,7 @@ def save_results(
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -563,8 +550,7 @@ def main():
     )
 
     dev_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
+        str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
     ]
     cuts_dev_webdataset = CutSet.from_webdataset(
         dev_shards,
@@ -574,8 +560,7 @@ def main():
     )
 
     test_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(test, "shared-*.tar")))
+        str(path) for path in sorted(glob.glob(os.path.join(test, "shared-*.tar")))
     ]
     cuts_test_webdataset = CutSet.from_webdataset(
         test_shards,
@@ -588,9 +573,7 @@ def main():
         return 1.0 <= c.duration
 
     cuts_dev_webdataset = cuts_dev_webdataset.filter(remove_short_and_long_utt)
-    cuts_test_webdataset = cuts_test_webdataset.filter(
-        remove_short_and_long_utt
-    )
+    cuts_test_webdataset = cuts_test_webdataset.filter(remove_short_and_long_utt)
 
     dev_dl = alimeeting.valid_dataloaders(cuts_dev_webdataset)
     test_dl = alimeeting.test_dataloaders(cuts_test_webdataset)

@@ -103,8 +103,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     return parser
@@ -173,9 +172,7 @@ def main():
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()

Some files were not shown because too many files have changed in this diff.
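
Nearly every hunk above is the same mechanical rewrite: black 22.3.0, now configured with an 88-character line length instead of 80, re-joins calls and expressions that the old limit had forced across several lines. The effect can be reproduced locally; below is a minimal sketch using black's Python API, where the sample source string is an illustrative stand-in rather than a line copied verbatim from any one file:

    # A minimal sketch, assuming black==22.3.0 is installed (pip install black==22.3.0).
    # The `src` snippet is an illustrative example of the wrapped calls in this diff.
    import black

    src = (
        "encoder_out, encoder_out_lens = model.encoder(\n"
        "    x=feature, x_lens=feature_lens\n"
        ")\n"
    )

    # line_length=88 matches the updated pre-commit hook; a call with no trailing
    # ("magic") comma that fits within 88 columns is re-joined onto one line.
    print(black.format_str(src, mode=black.Mode(line_length=88)))
    # -> encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)

Calls that keep their exploded form in the hunks above (for example the parser.add_argument blocks) do so because they end with a trailing comma, which black treats as a request to keep one argument per line.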