mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 17:42:21 +00:00
Apply delay penalty on k2 ctc loss (#669)
* add init files * fix bug, apply delay penalty * fix decoding code and getting timestamps * add option applying delay penalty on ctc log-prob * fix bug of streaming decoding * minor change for bpe-based case * add test_model.py * add README.md * add CI
This commit is contained in:
parent
6693d907d3
commit
ece728d895
2
.flake8
2
.flake8
@ -11,7 +11,7 @@ per-file-ignores =
|
|||||||
egs/*/ASR/*/scaling.py: E501,
|
egs/*/ASR/*/scaling.py: E501,
|
||||||
egs/librispeech/ASR/lstm_transducer_stateless*/*.py: E501, E203
|
egs/librispeech/ASR/lstm_transducer_stateless*/*.py: E501, E203
|
||||||
egs/librispeech/ASR/conv_emformer_transducer_stateless*/*.py: E501, E203
|
egs/librispeech/ASR/conv_emformer_transducer_stateless*/*.py: E501, E203
|
||||||
egs/librispeech/ASR/conformer_ctc2/*py: E501,
|
egs/librispeech/ASR/conformer_ctc*/*py: E501,
|
||||||
egs/librispeech/ASR/RESULTS.md: E999,
|
egs/librispeech/ASR/RESULTS.md: E999,
|
||||||
|
|
||||||
# invalid escape sequence (cause by tex formular), W605
|
# invalid escape sequence (cause by tex formular), W605
|
||||||
|
119
.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh
vendored
Executable file
119
.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh
vendored
Executable file
@ -0,0 +1,119 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
log() {
|
||||||
|
# This function is from espnet
|
||||||
|
local fname=${BASH_SOURCE[1]##*/}
|
||||||
|
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||||
|
}
|
||||||
|
|
||||||
|
cd egs/librispeech/ASR
|
||||||
|
|
||||||
|
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conformer-ctc3-2022-11-27
|
||||||
|
|
||||||
|
log "Downloading pre-trained model from $repo_url"
|
||||||
|
git lfs install
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
|
log "Display test files"
|
||||||
|
tree $repo/
|
||||||
|
soxi $repo/test_wavs/*.wav
|
||||||
|
ls -lh $repo/test_wavs/*.wav
|
||||||
|
|
||||||
|
pushd $repo/exp
|
||||||
|
git lfs pull --include "data/*"
|
||||||
|
git lfs pull --include "exp/jit_trace.pt"
|
||||||
|
git lfs pull --include "exp/pretrained.pt"
|
||||||
|
ln -s pretrained.pt epoch-99.pt
|
||||||
|
ls -lh *.pt
|
||||||
|
popd
|
||||||
|
|
||||||
|
log "Decode with models exported by torch.jit.trace()"
|
||||||
|
|
||||||
|
for m in ctc-decoding 1best; do
|
||||||
|
./conformer_ctc3/jit_pretrained.py \
|
||||||
|
--model-filename $repo/exp/jit_trace.pt \
|
||||||
|
--words-file $repo/data/lang_bpe_500/words.txt \
|
||||||
|
--HLG $repo/data/lang_bpe_500/HLG.pt \
|
||||||
|
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||||
|
--G $repo/data/lm/G_4_gram.pt \
|
||||||
|
--method $m \
|
||||||
|
--sample-rate 16000 \
|
||||||
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
done
|
||||||
|
|
||||||
|
log "Export to torchscript model"
|
||||||
|
|
||||||
|
./conformer_ctc3/export.py \
|
||||||
|
--exp-dir $repo/exp \
|
||||||
|
--lang-dir $repo/data/lang_bpe_500 \
|
||||||
|
--jit-trace 1 \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
--use-averaged-model 0
|
||||||
|
|
||||||
|
ls -lh $repo/exp/*.pt
|
||||||
|
|
||||||
|
log "Decode with models exported by torch.jit.trace()"
|
||||||
|
|
||||||
|
for m in ctc-decoding 1best; do
|
||||||
|
./conformer_ctc3/jit_pretrained.py \
|
||||||
|
--model-filename $repo/exp/jit_trace.pt \
|
||||||
|
--words-file $repo/data/lang_bpe_500/words.txt \
|
||||||
|
--HLG $repo/data/lang_bpe_500/HLG.pt \
|
||||||
|
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||||
|
--G $repo/data/lm/G_4_gram.pt \
|
||||||
|
--method $m \
|
||||||
|
--sample-rate 16000 \
|
||||||
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
done
|
||||||
|
|
||||||
|
for m in ctc-decoding 1best; do
|
||||||
|
./conformer_ctc3/pretrained.py \
|
||||||
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
|
--words-file $repo/data/lang_bpe_500/words.txt \
|
||||||
|
--HLG $repo/data/lang_bpe_500/HLG.pt \
|
||||||
|
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||||
|
--G $repo/data/lm/G_4_gram.pt \
|
||||||
|
--method $m \
|
||||||
|
--sample-rate 16000 \
|
||||||
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
done
|
||||||
|
|
||||||
|
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
|
||||||
|
echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
|
||||||
|
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
|
||||||
|
mkdir -p conformer_ctc3/exp
|
||||||
|
ln -s $PWD/$repo/exp/pretrained.pt conformer_ctc3/exp/epoch-999.pt
|
||||||
|
ln -s $PWD/$repo/data/lang_bpe_500 data/
|
||||||
|
|
||||||
|
ls -lh data
|
||||||
|
ls -lh conformer_ctc3/exp
|
||||||
|
|
||||||
|
log "Decoding test-clean and test-other"
|
||||||
|
|
||||||
|
# use a small value for decoding with CPU
|
||||||
|
max_duration=100
|
||||||
|
|
||||||
|
for method in ctc-decoding 1best; do
|
||||||
|
log "Decoding with $method"
|
||||||
|
./conformer_ctc3/decode.py \
|
||||||
|
--epoch 999 \
|
||||||
|
--avg 1 \
|
||||||
|
--use-averaged-model 0 \
|
||||||
|
--exp-dir conformer_ctc3/exp/ \
|
||||||
|
--max-duration $max_duration \
|
||||||
|
--decoding-method $method \
|
||||||
|
--lm-dir data/lm
|
||||||
|
done
|
||||||
|
|
||||||
|
rm conformer_ctc3/exp/*.pt
|
||||||
|
fi
|
151
.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml
vendored
Normal file
151
.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml
vendored
Normal file
@ -0,0 +1,151 @@
|
|||||||
|
# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com)
|
||||||
|
|
||||||
|
# See ../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
name: run-librispeech-conformer-ctc3-2022-11-28
|
||||||
|
# zipformer
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
pull_request:
|
||||||
|
types: [labeled]
|
||||||
|
|
||||||
|
schedule:
|
||||||
|
# minute (0-59)
|
||||||
|
# hour (0-23)
|
||||||
|
# day of the month (1-31)
|
||||||
|
# month (1-12)
|
||||||
|
# day of the week (0-6)
|
||||||
|
# nightly build at 15:50 UTC time every day
|
||||||
|
- cron: "50 15 * * *"
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
run_librispeech_2022_11_28_conformer_ctc3:
|
||||||
|
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
os: [ubuntu-latest]
|
||||||
|
python-version: [3.8]
|
||||||
|
|
||||||
|
fail-fast: false
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- name: Setup Python ${{ matrix.python-version }}
|
||||||
|
uses: actions/setup-python@v2
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
cache: 'pip'
|
||||||
|
cache-dependency-path: '**/requirements-ci.txt'
|
||||||
|
|
||||||
|
- name: Install Python dependencies
|
||||||
|
run: |
|
||||||
|
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
|
||||||
|
pip uninstall -y protobuf
|
||||||
|
pip install --no-binary protobuf protobuf
|
||||||
|
|
||||||
|
- name: Cache kaldifeat
|
||||||
|
id: my-cache
|
||||||
|
uses: actions/cache@v2
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
~/tmp/kaldifeat
|
||||||
|
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
|
||||||
|
|
||||||
|
- name: Install kaldifeat
|
||||||
|
if: steps.my-cache.outputs.cache-hit != 'true'
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
.github/scripts/install-kaldifeat.sh
|
||||||
|
|
||||||
|
- name: Cache LibriSpeech test-clean and test-other datasets
|
||||||
|
id: libri-test-clean-and-test-other-data
|
||||||
|
uses: actions/cache@v2
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
~/tmp/download
|
||||||
|
key: cache-libri-test-clean-and-test-other
|
||||||
|
|
||||||
|
- name: Download LibriSpeech test-clean and test-other
|
||||||
|
if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
.github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
|
||||||
|
|
||||||
|
- name: Prepare manifests for LibriSpeech test-clean and test-other
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
.github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
|
||||||
|
|
||||||
|
- name: Cache LibriSpeech test-clean and test-other fbank features
|
||||||
|
id: libri-test-clean-and-test-other-fbank
|
||||||
|
uses: actions/cache@v2
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
~/tmp/fbank-libri
|
||||||
|
key: cache-libri-fbank-test-clean-and-test-other-v2
|
||||||
|
|
||||||
|
- name: Compute fbank for LibriSpeech test-clean and test-other
|
||||||
|
if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
|
||||||
|
|
||||||
|
- name: Inference with pre-trained model
|
||||||
|
shell: bash
|
||||||
|
env:
|
||||||
|
GITHUB_EVENT_NAME: ${{ github.event_name }}
|
||||||
|
GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
|
||||||
|
run: |
|
||||||
|
mkdir -p egs/librispeech/ASR/data
|
||||||
|
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
|
||||||
|
ls -lh egs/librispeech/ASR/data/*
|
||||||
|
|
||||||
|
sudo apt-get -qq install git-lfs tree sox
|
||||||
|
export PYTHONPATH=$PWD:$PYTHONPATH
|
||||||
|
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
|
||||||
|
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
|
||||||
|
|
||||||
|
.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh
|
||||||
|
|
||||||
|
- name: Display decoding results for librispeech conformer_ctc3
|
||||||
|
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
cd egs/librispeech/ASR/
|
||||||
|
tree ./conformer_ctc3/exp
|
||||||
|
|
||||||
|
cd conformer_ctc3
|
||||||
|
echo "results for conformer_ctc3"
|
||||||
|
echo "===ctc-decoding==="
|
||||||
|
find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
||||||
|
find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
||||||
|
|
||||||
|
echo "===1best==="
|
||||||
|
find exp/1best -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
||||||
|
find exp/1best -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
||||||
|
|
||||||
|
- name: Upload decoding results for librispeech conformer_ctc3
|
||||||
|
uses: actions/upload-artifact@v2
|
||||||
|
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||||
|
with:
|
||||||
|
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-conformer_ctc3-2022-11-28
|
||||||
|
path: egs/librispeech/ASR/conformer_ctc3/exp/
|
@ -1,5 +1,106 @@
|
|||||||
## Results
|
## Results
|
||||||
|
|
||||||
|
### LibriSpeech BPE training results (Conformer CTC, supporting delay penalty)
|
||||||
|
|
||||||
|
#### [conformer_ctc3](./conformer_ctc3)
|
||||||
|
|
||||||
|
It implements Conformer model training with CTC loss.
|
||||||
|
For streaming mode, it supports symbol delay penalty.
|
||||||
|
|
||||||
|
See <https://github.com/k2-fsa/icefall/pull/669> for more details.
|
||||||
|
|
||||||
|
##### training on full librispeech
|
||||||
|
|
||||||
|
This model contains 12 encoder layers. The number of model parameters is 77352694.
|
||||||
|
|
||||||
|
The WERs are:
|
||||||
|
|
||||||
|
| | test-clean | test-other | comment |
|
||||||
|
|-------------------------------------|------------|------------|----------------------|
|
||||||
|
| ctc-decoding | 3.09 | 7.62 | --epoch 25 --avg 7 |
|
||||||
|
| 1best | 2.87 | 6.44 | --epoch 25 --avg 7 |
|
||||||
|
| nbest | 2.88 | 6.5 | --epoch 25 --avg 7 |
|
||||||
|
| nbest-rescoring | 2.71 | 6.1 | --epoch 25 --avg 7 |
|
||||||
|
| whole-lattice-rescoring | 2.71 | 6.04 | --epoch 25 --avg 7 |
|
||||||
|
|
||||||
|
The training command is:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./conformer_ctc3/train.py \
|
||||||
|
--world-size 4 \
|
||||||
|
--num-epochs 25 \
|
||||||
|
--start-epoch 1 \
|
||||||
|
--exp-dir conformer_ctc3/full \
|
||||||
|
--full-libri 1 \
|
||||||
|
--max-duration 300 \
|
||||||
|
--master-port 12345
|
||||||
|
```
|
||||||
|
|
||||||
|
The tensorboard log can be found at
|
||||||
|
<https://tensorboard.dev/experiment/4jbxIQ2SQIaQeRqsR6bOSA>
|
||||||
|
|
||||||
|
The decoding command using different methods is:
|
||||||
|
```bash
|
||||||
|
for method in ctc-decoding 1best nbest nbest-rescoring whole-lattice-rescoring; do
|
||||||
|
./conformer_ctc3/decode.py \
|
||||||
|
--epoch 25 \
|
||||||
|
--avg 7 \
|
||||||
|
--exp-dir conformer_ctc3/exp \
|
||||||
|
--max-duration 300 \
|
||||||
|
--decoding-method $method \
|
||||||
|
--manifest-dir data/fbank \
|
||||||
|
--lm-dir data/lm \
|
||||||
|
done
|
||||||
|
```
|
||||||
|
|
||||||
|
Pretrained models, training logs, decoding logs, and decoding results
|
||||||
|
are available at
|
||||||
|
<https://huggingface.co/Zengwei/icefall-asr-librispeech-conformer-ctc3-2022-11-27>
|
||||||
|
|
||||||
|
The command to train a streaming model with symbol delay penalty is:
|
||||||
|
```bash
|
||||||
|
./conformer_ctc3/train.py \
|
||||||
|
--world-size 4 \
|
||||||
|
--num-epochs 30 \
|
||||||
|
--start-epoch 1 \
|
||||||
|
--exp-dir conformer_ctc3/exp \
|
||||||
|
--full-libri 1 \
|
||||||
|
--dynamic-chunk-training 1 \
|
||||||
|
--causal-convolution 1 \
|
||||||
|
--short-chunk-size 25 \
|
||||||
|
--num-left-chunks 4 \
|
||||||
|
--max-duration 300 \
|
||||||
|
--delay-penalty 0.1
|
||||||
|
```
|
||||||
|
To evaluate symbol delay, you should:
|
||||||
|
(1) Generate cuts with word-time alignments:
|
||||||
|
```bash
|
||||||
|
./local/add_alignment_librispeech.py \
|
||||||
|
--alignments-dir data/alignment \
|
||||||
|
--cuts-in-dir data/fbank \
|
||||||
|
--cuts-out-dir data/fbank_ali
|
||||||
|
```
|
||||||
|
(2) Set the argument "--manifest-dir data/fbank_ali" while decoding.
|
||||||
|
For example:
|
||||||
|
```bash
|
||||||
|
./conformer_ctc3/decode.py \
|
||||||
|
--epoch 25 \
|
||||||
|
--avg 7 \
|
||||||
|
--exp-dir ./conformer_ctc3/exp \
|
||||||
|
--max-duration 300 \
|
||||||
|
--decoding-method ctc-decoding \
|
||||||
|
--simulate-streaming 1 \
|
||||||
|
--causal-convolution 1 \
|
||||||
|
--decode-chunk-size 16 \
|
||||||
|
--left-context 64 \
|
||||||
|
--manifest-dir data/fbank_ali
|
||||||
|
```
|
||||||
|
Note: It supports to calculate symbol delay with following decoding methods:
|
||||||
|
- ctc-greedy-search
|
||||||
|
- ctc-decoding
|
||||||
|
- 1best
|
||||||
|
|
||||||
|
|
||||||
### pruned_transducer_stateless8 (zipformer + multidataset)
|
### pruned_transducer_stateless8 (zipformer + multidataset)
|
||||||
|
|
||||||
See <https://github.com/k2-fsa/icefall/pull/675> for more details.
|
See <https://github.com/k2-fsa/icefall/pull/675> for more details.
|
||||||
@ -115,7 +216,6 @@ done
|
|||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
### LibriSpeech BPE training results (Pruned Stateless LSTM RNN-T + gradient filter)
|
### LibriSpeech BPE training results (Pruned Stateless LSTM RNN-T + gradient filter)
|
||||||
|
|
||||||
#### [lstm_transducer_stateless3](./lstm_transducer_stateless3)
|
#### [lstm_transducer_stateless3](./lstm_transducer_stateless3)
|
||||||
|
1
egs/librispeech/ASR/conformer_ctc3/__init__.py
Symbolic link
1
egs/librispeech/ASR/conformer_ctc3/__init__.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../pruned_transducer_stateless2/__init__.py
|
1
egs/librispeech/ASR/conformer_ctc3/asr_datamodule.py
Symbolic link
1
egs/librispeech/ASR/conformer_ctc3/asr_datamodule.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../pruned_transducer_stateless2/asr_datamodule.py
|
1
egs/librispeech/ASR/conformer_ctc3/conformer.py
Symbolic link
1
egs/librispeech/ASR/conformer_ctc3/conformer.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../pruned_transducer_stateless2/conformer.py
|
1004
egs/librispeech/ASR/conformer_ctc3/decode.py
Executable file
1004
egs/librispeech/ASR/conformer_ctc3/decode.py
Executable file
File diff suppressed because it is too large
Load Diff
1
egs/librispeech/ASR/conformer_ctc3/encoder_interface.py
Symbolic link
1
egs/librispeech/ASR/conformer_ctc3/encoder_interface.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../pruned_transducer_stateless2/encoder_interface.py
|
292
egs/librispeech/ASR/conformer_ctc3/export.py
Executable file
292
egs/librispeech/ASR/conformer_ctc3/export.py
Executable file
@ -0,0 +1,292 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
#
|
||||||
|
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# This script converts several saved checkpoints
|
||||||
|
# to a single one using model averaging.
|
||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
|
||||||
|
(1) Export to torchscript model using torch.jit.trace()
|
||||||
|
|
||||||
|
./conformer_ctc3/export.py \
|
||||||
|
--exp-dir ./conformer_ctc3/exp \
|
||||||
|
--lang-dir data/lang_bpe_500 \
|
||||||
|
--epoch 20 \
|
||||||
|
--avg 10 \
|
||||||
|
--jit-trace 1
|
||||||
|
|
||||||
|
It will generates the file: `jit_trace.pt`.
|
||||||
|
|
||||||
|
(2) Export `model.state_dict()`
|
||||||
|
|
||||||
|
./conformer_ctc3/export.py \
|
||||||
|
--exp-dir ./conformer_ctc3/exp \
|
||||||
|
--lang-dir data/lang_bpe_500 \
|
||||||
|
--epoch 20 \
|
||||||
|
--avg 10
|
||||||
|
|
||||||
|
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
|
||||||
|
load it by `icefall.checkpoint.load_checkpoint()`.
|
||||||
|
|
||||||
|
To use the generated file with `conformer_ctc3/decode.py`,
|
||||||
|
you can do:
|
||||||
|
|
||||||
|
cd /path/to/exp_dir
|
||||||
|
ln -s pretrained.pt epoch-9999.pt
|
||||||
|
|
||||||
|
cd /path/to/egs/librispeech/ASR
|
||||||
|
./conformer_ctc3/decode.py \
|
||||||
|
--exp-dir ./conformer_ctc3/exp \
|
||||||
|
--epoch 9999 \
|
||||||
|
--avg 1 \
|
||||||
|
--max-duration 100 \
|
||||||
|
--lang-dir data/lang_bpe_500
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
|
from train import add_model_arguments, get_ctc_model, get_params
|
||||||
|
|
||||||
|
from icefall.checkpoint import (
|
||||||
|
average_checkpoints,
|
||||||
|
average_checkpoints_with_averaged_model,
|
||||||
|
find_checkpoints,
|
||||||
|
load_checkpoint,
|
||||||
|
)
|
||||||
|
from icefall.lexicon import Lexicon
|
||||||
|
from icefall.utils import str2bool
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch",
|
||||||
|
type=int,
|
||||||
|
default=28,
|
||||||
|
help="""It specifies the checkpoint to use for averaging.
|
||||||
|
Note: Epoch counts from 0.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--iter",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="""If positive, --epoch is ignored and it
|
||||||
|
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--avg",
|
||||||
|
type=int,
|
||||||
|
default=15,
|
||||||
|
help="Number of checkpoints to average. Automatically select "
|
||||||
|
"consecutive checkpoints before the checkpoint specified by "
|
||||||
|
"'--epoch' and '--iter'",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-averaged-model",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Whether to load averaged model. Currently it only supports "
|
||||||
|
"using --epoch. If True, it would decode with the averaged model "
|
||||||
|
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||||
|
"Actually only the models with epoch number of `epoch-avg` and "
|
||||||
|
"`epoch` are loaded for averaging. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--exp-dir",
|
||||||
|
type=str,
|
||||||
|
default="pruned_transducer_stateless4/exp",
|
||||||
|
help="""It specifies the directory where all training related
|
||||||
|
files, e.g., checkpoints, log, etc, are saved
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lang-dir",
|
||||||
|
type=Path,
|
||||||
|
default="data/lang_bpe_500",
|
||||||
|
help="The lang dir containing word table and LG graph",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--jit-trace",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""True to save a model after applying torch.jit.script.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--streaming-model",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""Whether to export a streaming model, if the models in exp-dir
|
||||||
|
are streaming model, this should be True.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
add_model_arguments(parser)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = get_parser().parse_args()
|
||||||
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
params.update(vars(args))
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", 0)
|
||||||
|
|
||||||
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
|
lexicon = Lexicon(params.lang_dir)
|
||||||
|
max_token_id = max(lexicon.tokens)
|
||||||
|
num_classes = max_token_id + 1 # +1 for the blank
|
||||||
|
params.vocab_size = num_classes
|
||||||
|
|
||||||
|
if params.streaming_model:
|
||||||
|
assert params.causal_convolution
|
||||||
|
|
||||||
|
logging.info(params)
|
||||||
|
|
||||||
|
logging.info("About to create model")
|
||||||
|
model = get_ctc_model(params)
|
||||||
|
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
if not params.use_averaged_model:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
elif params.avg == 1:
|
||||||
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
|
else:
|
||||||
|
start = params.epoch - params.avg + 1
|
||||||
|
filenames = []
|
||||||
|
for i in range(start, params.epoch + 1):
|
||||||
|
if i >= 1:
|
||||||
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
else:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg + 1
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg + 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
filename_start = filenames[-1]
|
||||||
|
filename_end = filenames[0]
|
||||||
|
logging.info(
|
||||||
|
"Calculating the averaged model over iteration checkpoints"
|
||||||
|
f" from {filename_start} (excluded) to {filename_end}"
|
||||||
|
)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert params.avg > 0, params.avg
|
||||||
|
start = params.epoch - params.avg
|
||||||
|
assert start >= 1, start
|
||||||
|
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||||
|
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||||
|
logging.info(
|
||||||
|
f"Calculating the averaged model over epoch range from "
|
||||||
|
f"{start} (excluded) to {params.epoch}"
|
||||||
|
)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
model.to("cpu")
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
if params.jit_trace:
|
||||||
|
# TODO: will support streaming mode
|
||||||
|
assert not params.streaming_model
|
||||||
|
convert_scaled_to_non_scaled(model, inplace=True)
|
||||||
|
|
||||||
|
logging.info("Using torch.jit.trace()")
|
||||||
|
|
||||||
|
x = torch.zeros(1, 100, 80, dtype=torch.float32)
|
||||||
|
x_lens = torch.tensor([100], dtype=torch.int64)
|
||||||
|
traced_model = torch.jit.trace(model, (x, x_lens))
|
||||||
|
|
||||||
|
filename = params.exp_dir / "jit_trace.pt"
|
||||||
|
traced_model.save(str(filename))
|
||||||
|
logging.info(f"Saved to {filename}")
|
||||||
|
else:
|
||||||
|
logging.info("Not using torch.jit.trace()")
|
||||||
|
# Save it using a format so that it can be loaded
|
||||||
|
# by :func:`load_checkpoint`
|
||||||
|
filename = params.exp_dir / "pretrained.pt"
|
||||||
|
torch.save({"model": model.state_dict()}, str(filename))
|
||||||
|
logging.info(f"Saved to {filename}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
main()
|
406
egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py
Executable file
406
egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py
Executable file
@ -0,0 +1,406 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||||
|
# Mingshuang Luo,)
|
||||||
|
# Zengwei Yao)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Usage (for non-streaming mode):
|
||||||
|
|
||||||
|
(1) ctc-decoding
|
||||||
|
./conformer_ctc3/pretrained.py \
|
||||||
|
--checkpoint conformer_ctc3/exp/pretrained.pt \
|
||||||
|
--bpe-model data/lang_bpe_500/bpe.model \
|
||||||
|
--method ctc-decoding \
|
||||||
|
--sample-rate 16000 \
|
||||||
|
test_wavs/1089-134686-0001.wav
|
||||||
|
|
||||||
|
(2) 1best
|
||||||
|
./conformer_ctc3/pretrained.py \
|
||||||
|
--checkpoint conformer_ctc3/exp/pretrained.pt \
|
||||||
|
--HLG data/lang_bpe_500/HLG.pt \
|
||||||
|
--words-file data/lang_bpe_500/words.txt \
|
||||||
|
--method 1best \
|
||||||
|
--sample-rate 16000 \
|
||||||
|
test_wavs/1089-134686-0001.wav
|
||||||
|
|
||||||
|
(3) nbest-rescoring
|
||||||
|
./conformer_ctc3/pretrained.py \
|
||||||
|
--checkpoint conformer_ctc3/exp/pretrained.pt \
|
||||||
|
--HLG data/lang_bpe_500/HLG.pt \
|
||||||
|
--words-file data/lang_bpe_500/words.txt \
|
||||||
|
--G data/lm/G_4_gram.pt \
|
||||||
|
--method nbest-rescoring \
|
||||||
|
--sample-rate 16000 \
|
||||||
|
test_wavs/1089-134686-0001.wav
|
||||||
|
|
||||||
|
(4) whole-lattice-rescoring
|
||||||
|
./conformer_ctc3/pretrained.py \
|
||||||
|
--checkpoint conformer_ctc3/exp/pretrained.pt \
|
||||||
|
--HLG data/lang_bpe_500/HLG.pt \
|
||||||
|
--words-file data/lang_bpe_500/words.txt \
|
||||||
|
--G data/lm/G_4_gram.pt \
|
||||||
|
--method whole-lattice-rescoring \
|
||||||
|
--sample-rate 16000 \
|
||||||
|
test_wavs/1089-134686-0001.wav
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import kaldifeat
|
||||||
|
import sentencepiece as spm
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
from decode import get_decoding_params
|
||||||
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
from train import add_model_arguments, get_params
|
||||||
|
|
||||||
|
from icefall.decode import (
|
||||||
|
get_lattice,
|
||||||
|
one_best_decoding,
|
||||||
|
rescore_with_n_best_list,
|
||||||
|
rescore_with_whole_lattice,
|
||||||
|
)
|
||||||
|
from icefall.utils import get_texts
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--model-filename",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the torchscript model.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--words-file",
|
||||||
|
type=str,
|
||||||
|
help="""Path to words.txt.
|
||||||
|
Used only when method is not ctc-decoding.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--HLG",
|
||||||
|
type=str,
|
||||||
|
help="""Path to HLG.pt.
|
||||||
|
Used only when method is not ctc-decoding.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--bpe-model",
|
||||||
|
type=str,
|
||||||
|
help="""Path to bpe.model.
|
||||||
|
Used only when method is ctc-decoding.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--method",
|
||||||
|
type=str,
|
||||||
|
default="1best",
|
||||||
|
help="""Decoding method.
|
||||||
|
Possible values are:
|
||||||
|
(0) ctc-decoding - Use CTC decoding. It uses a sentence
|
||||||
|
piece model, i.e., lang_dir/bpe.model, to convert
|
||||||
|
word pieces to words. It needs neither a lexicon
|
||||||
|
nor an n-gram LM.
|
||||||
|
(1) 1best - Use the best path as decoding output. Only
|
||||||
|
the transformer encoder output is used for decoding.
|
||||||
|
We call it HLG decoding.
|
||||||
|
(2) nbest-rescoring. Extract n paths from the decoding lattice,
|
||||||
|
rescore them with an LM, the path with
|
||||||
|
the highest score is the decoding result.
|
||||||
|
We call it HLG decoding + n-gram LM rescoring.
|
||||||
|
(3) whole-lattice-rescoring - Use an LM to rescore the
|
||||||
|
decoding lattice and then use 1best to decode the
|
||||||
|
rescored lattice.
|
||||||
|
We call it HLG decoding + n-gram LM rescoring.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--G",
|
||||||
|
type=str,
|
||||||
|
help="""An LM for rescoring.
|
||||||
|
Used only when method is
|
||||||
|
whole-lattice-rescoring or nbest-rescoring.
|
||||||
|
It's usually a 4-gram LM.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-paths",
|
||||||
|
type=int,
|
||||||
|
default=100,
|
||||||
|
help="""
|
||||||
|
Used only when method is attention-decoder.
|
||||||
|
It specifies the size of n-best list.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--ngram-lm-scale",
|
||||||
|
type=float,
|
||||||
|
default=1.3,
|
||||||
|
help="""
|
||||||
|
Used only when method is whole-lattice-rescoring and nbest-rescoring.
|
||||||
|
It specifies the scale for n-gram LM scores.
|
||||||
|
(Note: You need to tune it on a dataset.)
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--nbest-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.5,
|
||||||
|
help="""
|
||||||
|
Used only when method is nbest-rescoring.
|
||||||
|
It specifies the scale for lattice.scores when
|
||||||
|
extracting n-best lists. A smaller value results in
|
||||||
|
more unique number of paths with the risk of missing
|
||||||
|
the best path.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-classes",
|
||||||
|
type=int,
|
||||||
|
default=500,
|
||||||
|
help="""
|
||||||
|
Vocab size in the BPE model.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--sample-rate",
|
||||||
|
type=int,
|
||||||
|
default=16000,
|
||||||
|
help="The sample rate of the input sound file",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"sound_files",
|
||||||
|
type=str,
|
||||||
|
nargs="+",
|
||||||
|
help="The input sound file(s) to transcribe. "
|
||||||
|
"Supported formats are those supported by torchaudio.load(). "
|
||||||
|
"For example, wav and flac are supported. "
|
||||||
|
"The sample rate has to be 16kHz.",
|
||||||
|
)
|
||||||
|
|
||||||
|
add_model_arguments(parser)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def read_sound_files(
|
||||||
|
filenames: List[str], expected_sample_rate: float
|
||||||
|
) -> List[torch.Tensor]:
|
||||||
|
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||||
|
Args:
|
||||||
|
filenames:
|
||||||
|
A list of sound filenames.
|
||||||
|
expected_sample_rate:
|
||||||
|
The expected sample rate of the sound files.
|
||||||
|
Returns:
|
||||||
|
Return a list of 1-D float32 torch tensors.
|
||||||
|
"""
|
||||||
|
ans = []
|
||||||
|
for f in filenames:
|
||||||
|
wave, sample_rate = torchaudio.load(f)
|
||||||
|
assert sample_rate == expected_sample_rate, (
|
||||||
|
f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
|
||||||
|
)
|
||||||
|
# We use only the first channel
|
||||||
|
ans.append(wave[0])
|
||||||
|
return ans
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = get_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
# add decoding params
|
||||||
|
params.update(get_decoding_params())
|
||||||
|
params.update(vars(args))
|
||||||
|
params.vocab_size = params.num_classes
|
||||||
|
|
||||||
|
logging.info(f"{params}")
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
|
||||||
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
|
model = torch.jit.load(args.model_filename)
|
||||||
|
model.to(device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
logging.info("Constructing Fbank computer")
|
||||||
|
opts = kaldifeat.FbankOptions()
|
||||||
|
opts.device = device
|
||||||
|
opts.frame_opts.dither = 0
|
||||||
|
opts.frame_opts.snip_edges = False
|
||||||
|
opts.frame_opts.samp_freq = params.sample_rate
|
||||||
|
opts.mel_opts.num_bins = params.feature_dim
|
||||||
|
|
||||||
|
fbank = kaldifeat.Fbank(opts)
|
||||||
|
|
||||||
|
logging.info(f"Reading sound files: {params.sound_files}")
|
||||||
|
waves = read_sound_files(
|
||||||
|
filenames=params.sound_files, expected_sample_rate=params.sample_rate
|
||||||
|
)
|
||||||
|
waves = [w.to(device) for w in waves]
|
||||||
|
|
||||||
|
logging.info("Decoding started")
|
||||||
|
features = fbank(waves)
|
||||||
|
feature_lengths = [f.size(0) for f in features]
|
||||||
|
|
||||||
|
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
|
||||||
|
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||||
|
|
||||||
|
nnet_output, _ = model(features, feature_lengths)
|
||||||
|
|
||||||
|
batch_size = nnet_output.shape[0]
|
||||||
|
supervision_segments = torch.tensor(
|
||||||
|
[[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
|
||||||
|
dtype=torch.int32,
|
||||||
|
)
|
||||||
|
|
||||||
|
if params.method == "ctc-decoding":
|
||||||
|
logging.info("Use CTC decoding")
|
||||||
|
bpe_model = spm.SentencePieceProcessor()
|
||||||
|
bpe_model.load(params.bpe_model)
|
||||||
|
max_token_id = params.num_classes - 1
|
||||||
|
|
||||||
|
H = k2.ctc_topo(
|
||||||
|
max_token=max_token_id,
|
||||||
|
modified=False,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
lattice = get_lattice(
|
||||||
|
nnet_output=nnet_output,
|
||||||
|
decoding_graph=H,
|
||||||
|
supervision_segments=supervision_segments,
|
||||||
|
search_beam=params.search_beam,
|
||||||
|
output_beam=params.output_beam,
|
||||||
|
min_active_states=params.min_active_states,
|
||||||
|
max_active_states=params.max_active_states,
|
||||||
|
subsampling_factor=params.subsampling_factor,
|
||||||
|
)
|
||||||
|
|
||||||
|
best_path = one_best_decoding(
|
||||||
|
lattice=lattice, use_double_scores=params.use_double_scores
|
||||||
|
)
|
||||||
|
token_ids = get_texts(best_path)
|
||||||
|
hyps = bpe_model.decode(token_ids)
|
||||||
|
hyps = [s.split() for s in hyps]
|
||||||
|
elif params.method in [
|
||||||
|
"1best",
|
||||||
|
"nbest-rescoring",
|
||||||
|
"whole-lattice-rescoring",
|
||||||
|
]:
|
||||||
|
logging.info(f"Loading HLG from {params.HLG}")
|
||||||
|
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
|
||||||
|
HLG = HLG.to(device)
|
||||||
|
if not hasattr(HLG, "lm_scores"):
|
||||||
|
# For whole-lattice-rescoring and attention-decoder
|
||||||
|
HLG.lm_scores = HLG.scores.clone()
|
||||||
|
|
||||||
|
if params.method in [
|
||||||
|
"nbest-rescoring",
|
||||||
|
"whole-lattice-rescoring",
|
||||||
|
]:
|
||||||
|
logging.info(f"Loading G from {params.G}")
|
||||||
|
G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
|
||||||
|
G = G.to(device)
|
||||||
|
if params.method == "whole-lattice-rescoring":
|
||||||
|
# Add epsilon self-loops to G as we will compose
|
||||||
|
# it with the whole lattice later
|
||||||
|
G = k2.add_epsilon_self_loops(G)
|
||||||
|
G = k2.arc_sort(G)
|
||||||
|
|
||||||
|
# G.lm_scores is used to replace HLG.lm_scores during
|
||||||
|
# LM rescoring.
|
||||||
|
G.lm_scores = G.scores.clone()
|
||||||
|
|
||||||
|
lattice = get_lattice(
|
||||||
|
nnet_output=nnet_output,
|
||||||
|
decoding_graph=HLG,
|
||||||
|
supervision_segments=supervision_segments,
|
||||||
|
search_beam=params.search_beam,
|
||||||
|
output_beam=params.output_beam,
|
||||||
|
min_active_states=params.min_active_states,
|
||||||
|
max_active_states=params.max_active_states,
|
||||||
|
subsampling_factor=params.subsampling_factor,
|
||||||
|
)
|
||||||
|
|
||||||
|
if params.method == "1best":
|
||||||
|
logging.info("Use HLG decoding")
|
||||||
|
best_path = one_best_decoding(
|
||||||
|
lattice=lattice, use_double_scores=params.use_double_scores
|
||||||
|
)
|
||||||
|
if params.method == "nbest-rescoring":
|
||||||
|
logging.info("Use HLG decoding + LM rescoring")
|
||||||
|
best_path_dict = rescore_with_n_best_list(
|
||||||
|
lattice=lattice,
|
||||||
|
G=G,
|
||||||
|
num_paths=params.num_paths,
|
||||||
|
lm_scale_list=[params.ngram_lm_scale],
|
||||||
|
nbest_scale=params.nbest_scale,
|
||||||
|
)
|
||||||
|
best_path = next(iter(best_path_dict.values()))
|
||||||
|
elif params.method == "whole-lattice-rescoring":
|
||||||
|
logging.info("Use HLG decoding + LM rescoring")
|
||||||
|
best_path_dict = rescore_with_whole_lattice(
|
||||||
|
lattice=lattice,
|
||||||
|
G_with_epsilon_loops=G,
|
||||||
|
lm_scale_list=[params.ngram_lm_scale],
|
||||||
|
)
|
||||||
|
best_path = next(iter(best_path_dict.values()))
|
||||||
|
|
||||||
|
hyps = get_texts(best_path)
|
||||||
|
word_sym_table = k2.SymbolTable.from_file(params.words_file)
|
||||||
|
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported decoding method: {params.method}")
|
||||||
|
|
||||||
|
s = "\n"
|
||||||
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
|
words = " ".join(hyp)
|
||||||
|
s += f"{filename}:\n{words}\n\n"
|
||||||
|
logging.info(s)
|
||||||
|
|
||||||
|
logging.info("Decoding Done")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
main()
|
1
egs/librispeech/ASR/conformer_ctc3/lstmp.py
Symbolic link
1
egs/librispeech/ASR/conformer_ctc3/lstmp.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../lstm_transducer_stateless2/lstmp.py
|
122
egs/librispeech/ASR/conformer_ctc3/model.py
Normal file
122
egs/librispeech/ASR/conformer_ctc3/model.py
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||||
|
# Wei Kang,
|
||||||
|
# Zengwei Yao)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from encoder_interface import EncoderInterface
|
||||||
|
from scaling import ScaledLinear
|
||||||
|
|
||||||
|
|
||||||
|
class CTCModel(nn.Module):
|
||||||
|
"""It implements https://www.cs.toronto.edu/~graves/icml_2006.pdf
|
||||||
|
"Connectionist Temporal Classification: Labelling Unsegmented
|
||||||
|
Sequence Data with Recurrent Neural Networks"
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
encoder: EncoderInterface,
|
||||||
|
encoder_dim: int,
|
||||||
|
vocab_size: int,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
encoder:
|
||||||
|
It is the transcription network in the paper. Its accepts
|
||||||
|
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
|
||||||
|
It returns two tensors: `logits` of shape (N, T, encoder_dm) and
|
||||||
|
`logit_lens` of shape (N,).
|
||||||
|
encoder_dim:
|
||||||
|
The feature embedding dimension.
|
||||||
|
vocab_size:
|
||||||
|
The vocabulary size.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
assert isinstance(encoder, EncoderInterface), type(encoder)
|
||||||
|
|
||||||
|
self.encoder = encoder
|
||||||
|
self.ctc_output_module = nn.Sequential(
|
||||||
|
nn.Dropout(p=0.1),
|
||||||
|
ScaledLinear(encoder_dim, vocab_size),
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_ctc_output(
|
||||||
|
self,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
delay_penalty: float = 0.0,
|
||||||
|
blank_threshold: float = 0.99,
|
||||||
|
):
|
||||||
|
"""Compute ctc log-prob and optionally (delay_penalty > 0) apply delay penalty.
|
||||||
|
We first split utterance into sub-utterances according to the
|
||||||
|
blank probs, and then add sawtooth-like "blank-bonus" values to
|
||||||
|
the blank probs.
|
||||||
|
See https://github.com/k2-fsa/icefall/pull/669 for details.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
encoder_out:
|
||||||
|
A tensor with shape of (N, T, C).
|
||||||
|
delay_penalty:
|
||||||
|
A constant used to scale the delay penalty score.
|
||||||
|
blank_threshold:
|
||||||
|
The threshold used to split utterance into sub-utterances.
|
||||||
|
"""
|
||||||
|
output = self.ctc_output_module(encoder_out)
|
||||||
|
log_prob = nn.functional.log_softmax(output, dim=-1)
|
||||||
|
|
||||||
|
if self.training and delay_penalty > 0:
|
||||||
|
T_arange = torch.arange(encoder_out.shape[1]).to(device=encoder_out.device)
|
||||||
|
# split into sub-utterances using the blank-id
|
||||||
|
mask = log_prob[:, :, 0] >= math.log(blank_threshold) # (B, T)
|
||||||
|
mask[:, 0] = True
|
||||||
|
cummax_out = (T_arange * mask).cummax(dim=-1)[0] # (B, T)
|
||||||
|
# the sawtooth "blank-bonus" value
|
||||||
|
penalty = T_arange - cummax_out # (B, T)
|
||||||
|
penalty_all = torch.zeros_like(log_prob)
|
||||||
|
penalty_all[:, :, 0] = delay_penalty * penalty
|
||||||
|
# apply latency penalty on probs
|
||||||
|
log_prob = log_prob + penalty_all
|
||||||
|
|
||||||
|
return log_prob
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
x_lens: torch.Tensor,
|
||||||
|
warmup: float = 1.0,
|
||||||
|
delay_penalty: float = 0.0,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x:
|
||||||
|
A 3-D tensor of shape (N, T, C).
|
||||||
|
x_lens:
|
||||||
|
A 1-D tensor of shape (N,). It contains the number of frames in `x`
|
||||||
|
before padding.
|
||||||
|
warmup: a floating point value which increases throughout training;
|
||||||
|
values >= 1.0 are fully warmed up and have all modules present.
|
||||||
|
delay_penalty:
|
||||||
|
A constant used to scale the delay penalty score.
|
||||||
|
"""
|
||||||
|
encoder_out, encoder_out_lens = self.encoder(x, x_lens, warmup=warmup)
|
||||||
|
assert torch.all(encoder_out_lens > 0)
|
||||||
|
nnet_output = self.get_ctc_output(encoder_out, delay_penalty=delay_penalty)
|
||||||
|
return nnet_output, encoder_out_lens
|
1
egs/librispeech/ASR/conformer_ctc3/optim.py
Symbolic link
1
egs/librispeech/ASR/conformer_ctc3/optim.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../pruned_transducer_stateless2/optim.py
|
458
egs/librispeech/ASR/conformer_ctc3/pretrained.py
Executable file
458
egs/librispeech/ASR/conformer_ctc3/pretrained.py
Executable file
@ -0,0 +1,458 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||||
|
# Mingshuang Luo,)
|
||||||
|
# Zengwei Yao)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Usage (for non-streaming mode):
|
||||||
|
|
||||||
|
(1) ctc-decoding
|
||||||
|
./conformer_ctc3/pretrained.py \
|
||||||
|
--checkpoint conformer_ctc3/exp/pretrained.pt \
|
||||||
|
--bpe-model data/lang_bpe_500/bpe.model \
|
||||||
|
--method ctc-decoding \
|
||||||
|
--sample-rate 16000 \
|
||||||
|
test_wavs/1089-134686-0001.wav
|
||||||
|
|
||||||
|
(2) 1best
|
||||||
|
./conformer_ctc3/pretrained.py \
|
||||||
|
--checkpoint conformer_ctc3/exp/pretrained.pt \
|
||||||
|
--HLG data/lang_bpe_500/HLG.pt \
|
||||||
|
--words-file data/lang_bpe_500/words.txt \
|
||||||
|
--method 1best \
|
||||||
|
--sample-rate 16000 \
|
||||||
|
test_wavs/1089-134686-0001.wav
|
||||||
|
|
||||||
|
(3) nbest-rescoring
|
||||||
|
./conformer_ctc3/pretrained.py \
|
||||||
|
--checkpoint conformer_ctc3/exp/pretrained.pt \
|
||||||
|
--HLG data/lang_bpe_500/HLG.pt \
|
||||||
|
--words-file data/lang_bpe_500/words.txt \
|
||||||
|
--G data/lm/G_4_gram.pt \
|
||||||
|
--method nbest-rescoring \
|
||||||
|
--sample-rate 16000 \
|
||||||
|
test_wavs/1089-134686-0001.wav
|
||||||
|
|
||||||
|
(4) whole-lattice-rescoring
|
||||||
|
./conformer_ctc3/pretrained.py \
|
||||||
|
--checkpoint conformer_ctc3/exp/pretrained.pt \
|
||||||
|
--HLG data/lang_bpe_500/HLG.pt \
|
||||||
|
--words-file data/lang_bpe_500/words.txt \
|
||||||
|
--G data/lm/G_4_gram.pt \
|
||||||
|
--method whole-lattice-rescoring \
|
||||||
|
--sample-rate 16000 \
|
||||||
|
test_wavs/1089-134686-0001.wav
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import kaldifeat
|
||||||
|
import sentencepiece as spm
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
from decode import get_decoding_params
|
||||||
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
from train import add_model_arguments, get_ctc_model, get_params
|
||||||
|
|
||||||
|
from icefall.decode import (
|
||||||
|
get_lattice,
|
||||||
|
one_best_decoding,
|
||||||
|
rescore_with_n_best_list,
|
||||||
|
rescore_with_whole_lattice,
|
||||||
|
)
|
||||||
|
from icefall.utils import get_texts, str2bool
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--checkpoint",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the checkpoint. "
|
||||||
|
"The checkpoint is assumed to be saved by "
|
||||||
|
"icefall.checkpoint.save_checkpoint().",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--words-file",
|
||||||
|
type=str,
|
||||||
|
help="""Path to words.txt.
|
||||||
|
Used only when method is not ctc-decoding.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--HLG",
|
||||||
|
type=str,
|
||||||
|
help="""Path to HLG.pt.
|
||||||
|
Used only when method is not ctc-decoding.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--bpe-model",
|
||||||
|
type=str,
|
||||||
|
help="""Path to bpe.model.
|
||||||
|
Used only when method is ctc-decoding.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--method",
|
||||||
|
type=str,
|
||||||
|
default="1best",
|
||||||
|
help="""Decoding method.
|
||||||
|
Possible values are:
|
||||||
|
(0) ctc-decoding - Use CTC decoding. It uses a sentence
|
||||||
|
piece model, i.e., lang_dir/bpe.model, to convert
|
||||||
|
word pieces to words. It needs neither a lexicon
|
||||||
|
nor an n-gram LM.
|
||||||
|
(1) 1best - Use the best path as decoding output. Only
|
||||||
|
the transformer encoder output is used for decoding.
|
||||||
|
We call it HLG decoding.
|
||||||
|
(2) nbest-rescoring. Extract n paths from the decoding lattice,
|
||||||
|
rescore them with an LM, the path with
|
||||||
|
the highest score is the decoding result.
|
||||||
|
We call it HLG decoding + n-gram LM rescoring.
|
||||||
|
(3) whole-lattice-rescoring - Use an LM to rescore the
|
||||||
|
decoding lattice and then use 1best to decode the
|
||||||
|
rescored lattice.
|
||||||
|
We call it HLG decoding + n-gram LM rescoring.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--G",
|
||||||
|
type=str,
|
||||||
|
help="""An LM for rescoring.
|
||||||
|
Used only when method is
|
||||||
|
whole-lattice-rescoring or nbest-rescoring.
|
||||||
|
It's usually a 4-gram LM.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-paths",
|
||||||
|
type=int,
|
||||||
|
default=100,
|
||||||
|
help="""
|
||||||
|
Used only when method is attention-decoder.
|
||||||
|
It specifies the size of n-best list.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--ngram-lm-scale",
|
||||||
|
type=float,
|
||||||
|
default=1.3,
|
||||||
|
help="""
|
||||||
|
Used only when method is whole-lattice-rescoring and nbest-rescoring.
|
||||||
|
It specifies the scale for n-gram LM scores.
|
||||||
|
(Note: You need to tune it on a dataset.)
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--nbest-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.5,
|
||||||
|
help="""
|
||||||
|
Used only when method is nbest-rescoring.
|
||||||
|
It specifies the scale for lattice.scores when
|
||||||
|
extracting n-best lists. A smaller value results in
|
||||||
|
more unique number of paths with the risk of missing
|
||||||
|
the best path.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-classes",
|
||||||
|
type=int,
|
||||||
|
default=500,
|
||||||
|
help="""
|
||||||
|
Vocab size in the BPE model.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--simulate-streaming",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""Whether to simulate streaming in decoding, this is a good way to
|
||||||
|
test a streaming model.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--decode-chunk-size",
|
||||||
|
type=int,
|
||||||
|
default=16,
|
||||||
|
help="The chunk size for decoding (in frames after subsampling)",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--left-context",
|
||||||
|
type=int,
|
||||||
|
default=64,
|
||||||
|
help="left context can be seen during decoding (in frames after subsampling)",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--sample-rate",
|
||||||
|
type=int,
|
||||||
|
default=16000,
|
||||||
|
help="The sample rate of the input sound file",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"sound_files",
|
||||||
|
type=str,
|
||||||
|
nargs="+",
|
||||||
|
help="The input sound file(s) to transcribe. "
|
||||||
|
"Supported formats are those supported by torchaudio.load(). "
|
||||||
|
"For example, wav and flac are supported. "
|
||||||
|
"The sample rate has to be 16kHz.",
|
||||||
|
)
|
||||||
|
|
||||||
|
add_model_arguments(parser)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def read_sound_files(
|
||||||
|
filenames: List[str], expected_sample_rate: float
|
||||||
|
) -> List[torch.Tensor]:
|
||||||
|
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||||
|
Args:
|
||||||
|
filenames:
|
||||||
|
A list of sound filenames.
|
||||||
|
expected_sample_rate:
|
||||||
|
The expected sample rate of the sound files.
|
||||||
|
Returns:
|
||||||
|
Return a list of 1-D float32 torch tensors.
|
||||||
|
"""
|
||||||
|
ans = []
|
||||||
|
for f in filenames:
|
||||||
|
wave, sample_rate = torchaudio.load(f)
|
||||||
|
assert sample_rate == expected_sample_rate, (
|
||||||
|
f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
|
||||||
|
)
|
||||||
|
# We use only the first channel
|
||||||
|
ans.append(wave[0])
|
||||||
|
return ans
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = get_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
# add decoding params
|
||||||
|
params.update(get_decoding_params())
|
||||||
|
params.update(vars(args))
|
||||||
|
params.vocab_size = params.num_classes
|
||||||
|
|
||||||
|
if params.simulate_streaming:
|
||||||
|
assert (
|
||||||
|
params.causal_convolution
|
||||||
|
), "Decoding in streaming requires causal convolution"
|
||||||
|
|
||||||
|
logging.info(f"{params}")
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", 0)
|
||||||
|
|
||||||
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
|
logging.info("About to create model")
|
||||||
|
model = get_ctc_model(params)
|
||||||
|
|
||||||
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
|
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||||
|
model.load_state_dict(checkpoint["model"], strict=False)
|
||||||
|
model.to(device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
logging.info("Constructing Fbank computer")
|
||||||
|
opts = kaldifeat.FbankOptions()
|
||||||
|
opts.device = device
|
||||||
|
opts.frame_opts.dither = 0
|
||||||
|
opts.frame_opts.snip_edges = False
|
||||||
|
opts.frame_opts.samp_freq = params.sample_rate
|
||||||
|
opts.mel_opts.num_bins = params.feature_dim
|
||||||
|
|
||||||
|
fbank = kaldifeat.Fbank(opts)
|
||||||
|
|
||||||
|
logging.info(f"Reading sound files: {params.sound_files}")
|
||||||
|
waves = read_sound_files(
|
||||||
|
filenames=params.sound_files, expected_sample_rate=params.sample_rate
|
||||||
|
)
|
||||||
|
waves = [w.to(device) for w in waves]
|
||||||
|
|
||||||
|
logging.info("Decoding started")
|
||||||
|
features = fbank(waves)
|
||||||
|
feature_lengths = [f.size(0) for f in features]
|
||||||
|
|
||||||
|
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
|
||||||
|
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||||
|
|
||||||
|
# model forward
|
||||||
|
if params.simulate_streaming:
|
||||||
|
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
||||||
|
x=features,
|
||||||
|
x_lens=feature_lengths,
|
||||||
|
chunk_size=params.decode_chunk_size,
|
||||||
|
left_context=params.left_context,
|
||||||
|
simulate_streaming=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
encoder_out, encoder_out_lens = model.encoder(
|
||||||
|
x=features, x_lens=feature_lengths
|
||||||
|
)
|
||||||
|
nnet_output = model.get_ctc_output(encoder_out)
|
||||||
|
|
||||||
|
batch_size = nnet_output.shape[0]
|
||||||
|
supervision_segments = torch.tensor(
|
||||||
|
[[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
|
||||||
|
dtype=torch.int32,
|
||||||
|
)
|
||||||
|
|
||||||
|
if params.method == "ctc-decoding":
|
||||||
|
logging.info("Use CTC decoding")
|
||||||
|
bpe_model = spm.SentencePieceProcessor()
|
||||||
|
bpe_model.load(params.bpe_model)
|
||||||
|
max_token_id = params.num_classes - 1
|
||||||
|
|
||||||
|
H = k2.ctc_topo(
|
||||||
|
max_token=max_token_id,
|
||||||
|
modified=False,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
lattice = get_lattice(
|
||||||
|
nnet_output=nnet_output,
|
||||||
|
decoding_graph=H,
|
||||||
|
supervision_segments=supervision_segments,
|
||||||
|
search_beam=params.search_beam,
|
||||||
|
output_beam=params.output_beam,
|
||||||
|
min_active_states=params.min_active_states,
|
||||||
|
max_active_states=params.max_active_states,
|
||||||
|
subsampling_factor=params.subsampling_factor,
|
||||||
|
)
|
||||||
|
|
||||||
|
best_path = one_best_decoding(
|
||||||
|
lattice=lattice, use_double_scores=params.use_double_scores
|
||||||
|
)
|
||||||
|
token_ids = get_texts(best_path)
|
||||||
|
hyps = bpe_model.decode(token_ids)
|
||||||
|
hyps = [s.split() for s in hyps]
|
||||||
|
elif params.method in [
|
||||||
|
"1best",
|
||||||
|
"nbest-rescoring",
|
||||||
|
"whole-lattice-rescoring",
|
||||||
|
]:
|
||||||
|
logging.info(f"Loading HLG from {params.HLG}")
|
||||||
|
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
|
||||||
|
HLG = HLG.to(device)
|
||||||
|
if not hasattr(HLG, "lm_scores"):
|
||||||
|
# For whole-lattice-rescoring and attention-decoder
|
||||||
|
HLG.lm_scores = HLG.scores.clone()
|
||||||
|
|
||||||
|
if params.method in [
|
||||||
|
"nbest-rescoring",
|
||||||
|
"whole-lattice-rescoring",
|
||||||
|
]:
|
||||||
|
logging.info(f"Loading G from {params.G}")
|
||||||
|
G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
|
||||||
|
G = G.to(device)
|
||||||
|
if params.method == "whole-lattice-rescoring":
|
||||||
|
# Add epsilon self-loops to G as we will compose
|
||||||
|
# it with the whole lattice later
|
||||||
|
G = k2.add_epsilon_self_loops(G)
|
||||||
|
G = k2.arc_sort(G)
|
||||||
|
|
||||||
|
# G.lm_scores is used to replace HLG.lm_scores during
|
||||||
|
# LM rescoring.
|
||||||
|
G.lm_scores = G.scores.clone()
|
||||||
|
|
||||||
|
lattice = get_lattice(
|
||||||
|
nnet_output=nnet_output,
|
||||||
|
decoding_graph=HLG,
|
||||||
|
supervision_segments=supervision_segments,
|
||||||
|
search_beam=params.search_beam,
|
||||||
|
output_beam=params.output_beam,
|
||||||
|
min_active_states=params.min_active_states,
|
||||||
|
max_active_states=params.max_active_states,
|
||||||
|
subsampling_factor=params.subsampling_factor,
|
||||||
|
)
|
||||||
|
|
||||||
|
if params.method == "1best":
|
||||||
|
logging.info("Use HLG decoding")
|
||||||
|
best_path = one_best_decoding(
|
||||||
|
lattice=lattice, use_double_scores=params.use_double_scores
|
||||||
|
)
|
||||||
|
if params.method == "nbest-rescoring":
|
||||||
|
logging.info("Use HLG decoding + LM rescoring")
|
||||||
|
best_path_dict = rescore_with_n_best_list(
|
||||||
|
lattice=lattice,
|
||||||
|
G=G,
|
||||||
|
num_paths=params.num_paths,
|
||||||
|
lm_scale_list=[params.ngram_lm_scale],
|
||||||
|
nbest_scale=params.nbest_scale,
|
||||||
|
)
|
||||||
|
best_path = next(iter(best_path_dict.values()))
|
||||||
|
elif params.method == "whole-lattice-rescoring":
|
||||||
|
logging.info("Use HLG decoding + LM rescoring")
|
||||||
|
best_path_dict = rescore_with_whole_lattice(
|
||||||
|
lattice=lattice,
|
||||||
|
G_with_epsilon_loops=G,
|
||||||
|
lm_scale_list=[params.ngram_lm_scale],
|
||||||
|
)
|
||||||
|
best_path = next(iter(best_path_dict.values()))
|
||||||
|
|
||||||
|
hyps = get_texts(best_path)
|
||||||
|
word_sym_table = k2.SymbolTable.from_file(params.words_file)
|
||||||
|
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported decoding method: {params.method}")
|
||||||
|
|
||||||
|
s = "\n"
|
||||||
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
|
words = " ".join(hyp)
|
||||||
|
s += f"{filename}:\n{words}\n\n"
|
||||||
|
logging.info(s)
|
||||||
|
|
||||||
|
logging.info("Decoding Done")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
main()
|
1
egs/librispeech/ASR/conformer_ctc3/scaling.py
Symbolic link
1
egs/librispeech/ASR/conformer_ctc3/scaling.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../pruned_transducer_stateless2/scaling.py
|
1
egs/librispeech/ASR/conformer_ctc3/scaling_converter.py
Symbolic link
1
egs/librispeech/ASR/conformer_ctc3/scaling_converter.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../pruned_transducer_stateless3/scaling_converter.py
|
82
egs/librispeech/ASR/conformer_ctc3/test_model.py
Executable file
82
egs/librispeech/ASR/conformer_ctc3/test_model.py
Executable file
@ -0,0 +1,82 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
To run this file, do:
|
||||||
|
|
||||||
|
cd icefall/egs/librispeech/ASR
|
||||||
|
python ./conformer_ctc3/test_model.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from train import get_params, get_ctc_model
|
||||||
|
|
||||||
|
|
||||||
|
def test_model():
|
||||||
|
params = get_params()
|
||||||
|
params.vocab_size = 500
|
||||||
|
params.blank_id = 0
|
||||||
|
params.context_size = 2
|
||||||
|
params.unk_id = 2
|
||||||
|
|
||||||
|
params.dynamic_chunk_training = False
|
||||||
|
params.short_chunk_size = 25
|
||||||
|
params.num_left_chunks = 4
|
||||||
|
params.causal_convolution = False
|
||||||
|
|
||||||
|
model = get_ctc_model(params)
|
||||||
|
|
||||||
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
|
print(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
|
features = torch.randn(2, 100, 80)
|
||||||
|
feature_lengths = torch.full((2,), 100)
|
||||||
|
model(x=features, x_lens=feature_lengths)
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_streaming():
|
||||||
|
params = get_params()
|
||||||
|
params.vocab_size = 500
|
||||||
|
params.blank_id = 0
|
||||||
|
params.context_size = 2
|
||||||
|
params.unk_id = 2
|
||||||
|
|
||||||
|
params.dynamic_chunk_training = True
|
||||||
|
params.short_chunk_size = 25
|
||||||
|
params.num_left_chunks = 4
|
||||||
|
params.causal_convolution = True
|
||||||
|
|
||||||
|
model = get_ctc_model(params)
|
||||||
|
|
||||||
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
|
print(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
|
features = torch.randn(2, 100, 80)
|
||||||
|
feature_lengths = torch.full((2,), 100)
|
||||||
|
encoder_out, _ = model.encoder(x=features, x_lens=feature_lengths)
|
||||||
|
model.get_ctc_output(encoder_out)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
test_model()
|
||||||
|
test_model_streaming()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
1108
egs/librispeech/ASR/conformer_ctc3/train.py
Executable file
1108
egs/librispeech/ASR/conformer_ctc3/train.py
Executable file
File diff suppressed because it is too large
Load Diff
@ -83,11 +83,12 @@ class BpeCtcTrainingGraphCompiler(object):
|
|||||||
Args:
|
Args:
|
||||||
piece_ids:
|
piece_ids:
|
||||||
It is a list-of-list integer IDs.
|
It is a list-of-list integer IDs.
|
||||||
modified:
|
modified:
|
||||||
See :func:`k2.ctc_graph` for its meaning.
|
See :func:`k2.ctc_graph` for its meaning.
|
||||||
Return:
|
Return:
|
||||||
Return an FsaVec, which is the result of composing a
|
Return an FsaVec, which is the result of composing a
|
||||||
CTC topology with linear FSAs constructed from the given
|
CTC topology with linear FSAs constructed from the given
|
||||||
piece IDs.
|
piece IDs.
|
||||||
"""
|
"""
|
||||||
return k2.ctc_graph(piece_ids, modified=modified, device=self.device)
|
graph = k2.ctc_graph(piece_ids, modified=modified, device=self.device)
|
||||||
|
return graph
|
||||||
|
@ -117,4 +117,5 @@ class CharCtcTrainingGraphCompiler(object):
|
|||||||
CTC topology with linear FSAs constructed from the given
|
CTC topology with linear FSAs constructed from the given
|
||||||
piece IDs.
|
piece IDs.
|
||||||
"""
|
"""
|
||||||
return k2.ctc_graph(token_ids, modified=modified, device=self.device)
|
graph = k2.ctc_graph(token_ids, modified=modified, device=self.device)
|
||||||
|
return graph
|
||||||
|
@ -298,7 +298,7 @@ def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
|
|||||||
if not result:
|
if not result:
|
||||||
logging.warn(f"Invalid checkpoint filename {c}")
|
logging.warn(f"Invalid checkpoint filename {c}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
iter_checkpoints.append((int(result.group(1)), c))
|
iter_checkpoints.append((int(result.group(1)), c))
|
||||||
|
|
||||||
# iter_checkpoints is a list of tuples. Each tuple contains
|
# iter_checkpoints is a list of tuples. Each tuple contains
|
||||||
|
@ -79,6 +79,10 @@ class CtcTrainingGraphCompiler(object):
|
|||||||
|
|
||||||
fsa_with_self_loops = k2.arc_sort(fsa_with_self_loops)
|
fsa_with_self_loops = k2.arc_sort(fsa_with_self_loops)
|
||||||
|
|
||||||
|
self.ctc_topo._is_repeat_token_ = (
|
||||||
|
self.ctc_topo.labels != self.ctc_topo.aux_labels
|
||||||
|
)
|
||||||
|
|
||||||
decoding_graph = k2.compose(
|
decoding_graph = k2.compose(
|
||||||
self.ctc_topo, fsa_with_self_loops, treat_epsilons_specially=False
|
self.ctc_topo, fsa_with_self_loops, treat_epsilons_specially=False
|
||||||
)
|
)
|
||||||
|
@ -670,8 +670,8 @@ def write_error_stats_with_timestamps(
|
|||||||
all_delay = []
|
all_delay = []
|
||||||
for cut_id, ref, hyp, time_ref, time_hyp in results:
|
for cut_id, ref, hyp, time_ref, time_hyp in results:
|
||||||
ali = kaldialign.align(ref, hyp, ERR)
|
ali = kaldialign.align(ref, hyp, ERR)
|
||||||
has_time_ref = len(time_ref) > 0
|
has_time = len(time_ref) > 0 and len(time_hyp) > 0
|
||||||
if has_time_ref:
|
if has_time:
|
||||||
# pointer to timestamp_hyp
|
# pointer to timestamp_hyp
|
||||||
p_hyp = 0
|
p_hyp = 0
|
||||||
# pointer to timestamp_ref
|
# pointer to timestamp_ref
|
||||||
@ -680,28 +680,28 @@ def write_error_stats_with_timestamps(
|
|||||||
if ref_word == ERR:
|
if ref_word == ERR:
|
||||||
ins[hyp_word] += 1
|
ins[hyp_word] += 1
|
||||||
words[hyp_word][3] += 1
|
words[hyp_word][3] += 1
|
||||||
if has_time_ref:
|
if has_time:
|
||||||
p_hyp += 1
|
p_hyp += 1
|
||||||
elif hyp_word == ERR:
|
elif hyp_word == ERR:
|
||||||
dels[ref_word] += 1
|
dels[ref_word] += 1
|
||||||
words[ref_word][4] += 1
|
words[ref_word][4] += 1
|
||||||
if has_time_ref:
|
if has_time:
|
||||||
p_ref += 1
|
p_ref += 1
|
||||||
elif hyp_word != ref_word:
|
elif hyp_word != ref_word:
|
||||||
subs[(ref_word, hyp_word)] += 1
|
subs[(ref_word, hyp_word)] += 1
|
||||||
words[ref_word][1] += 1
|
words[ref_word][1] += 1
|
||||||
words[hyp_word][2] += 1
|
words[hyp_word][2] += 1
|
||||||
if has_time_ref:
|
if has_time:
|
||||||
p_hyp += 1
|
p_hyp += 1
|
||||||
p_ref += 1
|
p_ref += 1
|
||||||
else:
|
else:
|
||||||
words[ref_word][0] += 1
|
words[ref_word][0] += 1
|
||||||
num_corr += 1
|
num_corr += 1
|
||||||
if has_time_ref:
|
if has_time:
|
||||||
all_delay.append(time_hyp[p_hyp] - time_ref[p_ref])
|
all_delay.append(time_hyp[p_hyp] - time_ref[p_ref])
|
||||||
p_hyp += 1
|
p_hyp += 1
|
||||||
p_ref += 1
|
p_ref += 1
|
||||||
if has_time_ref:
|
if has_time:
|
||||||
assert p_hyp == len(hyp), (p_hyp, len(hyp))
|
assert p_hyp == len(hyp), (p_hyp, len(hyp))
|
||||||
assert p_ref == len(ref), (p_ref, len(ref))
|
assert p_ref == len(ref), (p_ref, len(ref))
|
||||||
|
|
||||||
@ -1327,10 +1327,9 @@ def parse_timestamp(tokens: List[str], timestamp: List[float]) -> List[float]:
|
|||||||
|
|
||||||
def parse_hyp_and_timestamp(
|
def parse_hyp_and_timestamp(
|
||||||
res: DecodingResults,
|
res: DecodingResults,
|
||||||
decoding_method: str,
|
|
||||||
sp: spm.SentencePieceProcessor,
|
|
||||||
subsampling_factor: int,
|
subsampling_factor: int,
|
||||||
frame_shift_ms: float = 10,
|
frame_shift_ms: float = 10,
|
||||||
|
sp: Optional[spm.SentencePieceProcessor] = None,
|
||||||
word_table: Optional[k2.SymbolTable] = None,
|
word_table: Optional[k2.SymbolTable] = None,
|
||||||
) -> Tuple[List[List[str]], List[List[float]]]:
|
) -> Tuple[List[List[str]], List[List[float]]]:
|
||||||
"""Parse hypothesis and timestamp.
|
"""Parse hypothesis and timestamp.
|
||||||
@ -1338,51 +1337,29 @@ def parse_hyp_and_timestamp(
|
|||||||
Args:
|
Args:
|
||||||
res:
|
res:
|
||||||
A DecodingResults object.
|
A DecodingResults object.
|
||||||
decoding_method:
|
|
||||||
Possible values are:
|
|
||||||
- greedy_search
|
|
||||||
- beam_search
|
|
||||||
- modified_beam_search
|
|
||||||
- fast_beam_search
|
|
||||||
- fast_beam_search_LG
|
|
||||||
- fast_beam_search_nbest
|
|
||||||
- fast_beam_search_nbest_oracle
|
|
||||||
- fast_beam_search_nbest_LG
|
|
||||||
sp:
|
|
||||||
The BPE model.
|
|
||||||
subsampling_factor:
|
subsampling_factor:
|
||||||
The integer subsampling factor.
|
The integer subsampling factor.
|
||||||
frame_shift_ms:
|
frame_shift_ms:
|
||||||
The float frame shift used for feature extraction.
|
The float frame shift used for feature extraction.
|
||||||
|
sp:
|
||||||
|
The BPE model.
|
||||||
word_table:
|
word_table:
|
||||||
The word symbol table.
|
The word symbol table.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Return a list of hypothesis and timestamp.
|
Return a list of hypothesis and timestamp.
|
||||||
"""
|
"""
|
||||||
assert decoding_method in (
|
|
||||||
"greedy_search",
|
|
||||||
"beam_search",
|
|
||||||
"fast_beam_search",
|
|
||||||
"fast_beam_search_LG",
|
|
||||||
"fast_beam_search_nbest",
|
|
||||||
"fast_beam_search_nbest_LG",
|
|
||||||
"fast_beam_search_nbest_oracle",
|
|
||||||
"modified_beam_search",
|
|
||||||
)
|
|
||||||
|
|
||||||
hyps = []
|
hyps = []
|
||||||
timestamps = []
|
timestamps = []
|
||||||
|
|
||||||
N = len(res.hyps)
|
N = len(res.hyps)
|
||||||
assert len(res.timestamps) == N, (len(res.timestamps), N)
|
assert len(res.timestamps) == N, (len(res.timestamps), N)
|
||||||
use_word_table = False
|
use_word_table = False
|
||||||
if (
|
if word_table is not None:
|
||||||
decoding_method == "fast_beam_search_nbest_LG"
|
assert sp is None
|
||||||
and decoding_method == "fast_beam_search_LG"
|
|
||||||
):
|
|
||||||
assert word_table is not None
|
|
||||||
use_word_table = True
|
use_word_table = True
|
||||||
|
else:
|
||||||
|
assert sp is not None and word_table is None
|
||||||
|
|
||||||
for i in range(N):
|
for i in range(N):
|
||||||
time = convert_timestamp(res.timestamps[i], subsampling_factor, frame_shift_ms)
|
time = convert_timestamp(res.timestamps[i], subsampling_factor, frame_shift_ms)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user