Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-07 08:04:18 +00:00)

Commit 03a632fc30: Merge branch 'master' of github.com:marcoyang1998/icefall into libriheavy_prompt_asr

1  .flake8
@@ -24,6 +24,7 @@ exclude =
   **/data/**,
   icefall/shared/make_kn_lm.py,
   icefall/__init__.py
+  icefall/ctc/__init__.py

 ignore =
   # E203 white space before ":"
46  .github/scripts/run-pre-trained-conformer-ctc.sh  (vendored)
@@ -1,46 +0,0 @@
#!/usr/bin/env bash

set -e

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

cd egs/librispeech/ASR

repo_url=https://github.com/csukuangfj/icefall-asr-conformer-ctc-bpe-500
git lfs install

log "Downloading pre-trained model from $repo_url"
git clone $repo_url
repo=$(basename $repo_url)

log "Display test files"
tree $repo/
ls -lh $repo/test_wavs/*.flac

log "CTC decoding"

./conformer_ctc/pretrained.py \
  --method ctc-decoding \
  --num-classes 500 \
  --checkpoint $repo/exp/pretrained.pt \
  --tokens $repo/data/lang_bpe_500/tokens.txt \
  $repo/test_wavs/1089-134686-0001.flac \
  $repo/test_wavs/1221-135766-0001.flac \
  $repo/test_wavs/1221-135766-0002.flac

log "HLG decoding"

./conformer_ctc/pretrained.py \
  --method 1best \
  --num-classes 500 \
  --checkpoint $repo/exp/pretrained.pt \
  --tokens $repo/data/lang_bpe_500/tokens.txt \
  --words-file $repo/data/lang_bpe_500/words.txt \
  --HLG $repo/data/lang_bpe_500/HLG.pt \
  $repo/test_wavs/1089-134686-0001.flac \
  $repo/test_wavs/1221-135766-0001.flac \
  $repo/test_wavs/1221-135766-0002.flac
240  .github/scripts/run-pre-trained-ctc.sh  (vendored, Executable file)
@@ -0,0 +1,240 @@
#!/usr/bin/env bash

set -e

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

pushd egs/librispeech/ASR

repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02
log "Downloading pre-trained model from $repo_url"
git lfs install
git clone $repo_url
repo=$(basename $repo_url)

log "Display test files"
tree $repo/
ls -lh $repo/test_wavs/*.wav

log "CTC greedy search"

./zipformer/onnx_pretrained_ctc.py \
  --nn-model $repo/model.onnx \
  --tokens $repo/tokens.txt \
  $repo/test_wavs/0.wav \
  $repo/test_wavs/1.wav \
  $repo/test_wavs/2.wav

log "CTC H decoding"

./zipformer/onnx_pretrained_ctc_H.py \
  --nn-model $repo/model.onnx \
  --tokens $repo/tokens.txt \
  --H $repo/H.fst \
  $repo/test_wavs/0.wav \
  $repo/test_wavs/1.wav \
  $repo/test_wavs/2.wav

log "CTC HL decoding"

./zipformer/onnx_pretrained_ctc_HL.py \
  --nn-model $repo/model.onnx \
  --words $repo/words.txt \
  --HL $repo/HL.fst \
  $repo/test_wavs/0.wav \
  $repo/test_wavs/1.wav \
  $repo/test_wavs/2.wav

log "CTC HLG decoding"

./zipformer/onnx_pretrained_ctc_HLG.py \
  --nn-model $repo/model.onnx \
  --words $repo/words.txt \
  --HLG $repo/HLG.fst \
  $repo/test_wavs/0.wav \
  $repo/test_wavs/1.wav \
  $repo/test_wavs/2.wav

rm -rf $repo

repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-conformer-ctc-jit-bpe-500-2021-11-09
log "Downloading pre-trained model from $repo_url"
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo

git lfs pull --include "exp/pretrained.pt"
git lfs pull --include "data/lang_bpe_500/HLG.pt"
git lfs pull --include "data/lang_bpe_500/L.pt"
git lfs pull --include "data/lang_bpe_500/L_disambig.pt"
git lfs pull --include "data/lang_bpe_500/Linv.pt"
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "data/lang_bpe_500/lexicon.txt"
git lfs pull --include "data/lang_bpe_500/lexicon_disambig.txt"
git lfs pull --include "data/lang_bpe_500/tokens.txt"
git lfs pull --include "data/lang_bpe_500/words.txt"
git lfs pull --include "data/lm/G_3_gram.fst.txt"

popd

log "Display test files"
tree $repo/
ls -lh $repo/test_wavs/*.wav

log "CTC decoding"

./conformer_ctc/pretrained.py \
  --method ctc-decoding \
  --num-classes 500 \
  --checkpoint $repo/exp/pretrained.pt \
  --tokens $repo/data/lang_bpe_500/tokens.txt \
  $repo/test_wavs/1089-134686-0001.wav \
  $repo/test_wavs/1221-135766-0001.wav \
  $repo/test_wavs/1221-135766-0002.wav

log "HLG decoding"

./conformer_ctc/pretrained.py \
  --method 1best \
  --num-classes 500 \
  --checkpoint $repo/exp/pretrained.pt \
  --tokens $repo/data/lang_bpe_500/tokens.txt \
  --words-file $repo/data/lang_bpe_500/words.txt \
  --HLG $repo/data/lang_bpe_500/HLG.pt \
  $repo/test_wavs/1089-134686-0001.wav \
  $repo/test_wavs/1221-135766-0001.wav \
  $repo/test_wavs/1221-135766-0002.wav

log "CTC decoding on CPU with kaldi decoders using OpenFst"

log "Exporting model with torchscript"

pushd $repo/exp
ln -s pretrained.pt epoch-99.pt
popd

./conformer_ctc/export.py \
  --epoch 99 \
  --avg 1 \
  --exp-dir $repo/exp \
  --tokens $repo/data/lang_bpe_500/tokens.txt \
  --jit 1

ls -lh $repo/exp


log "Generating H.fst, HL.fst"

./local/prepare_lang_fst.py --lang-dir $repo/data/lang_bpe_500 --ngram-G $repo/data/lm/G_3_gram.fst.txt

ls -lh $repo/data/lang_bpe_500

log "Decoding with H on CPU with OpenFst"

./conformer_ctc/jit_pretrained_decode_with_H.py \
  --nn-model $repo/exp/cpu_jit.pt \
  --H $repo/data/lang_bpe_500/H.fst \
  --tokens $repo/data/lang_bpe_500/tokens.txt \
  $repo/test_wavs/1089-134686-0001.wav \
  $repo/test_wavs/1221-135766-0001.wav \
  $repo/test_wavs/1221-135766-0002.wav

log "Decoding with HL on CPU with OpenFst"

./conformer_ctc/jit_pretrained_decode_with_HL.py \
  --nn-model $repo/exp/cpu_jit.pt \
  --HL $repo/data/lang_bpe_500/HL.fst \
  --words $repo/data/lang_bpe_500/words.txt \
  $repo/test_wavs/1089-134686-0001.wav \
  $repo/test_wavs/1221-135766-0001.wav \
  $repo/test_wavs/1221-135766-0002.wav

log "Decoding with HLG on CPU with OpenFst"

./conformer_ctc/jit_pretrained_decode_with_HLG.py \
  --nn-model $repo/exp/cpu_jit.pt \
  --HLG $repo/data/lang_bpe_500/HLG.fst \
  --words $repo/data/lang_bpe_500/words.txt \
  $repo/test_wavs/1089-134686-0001.wav \
  $repo/test_wavs/1221-135766-0001.wav \
  $repo/test_wavs/1221-135766-0002.wav

rm -rf $repo

popd

log "Test aishell"

pushd egs/aishell/ASR

repo_url=https://huggingface.co/csukuangfj/icefall_asr_aishell_conformer_ctc
log "Downloading pre-trained model from $repo_url"
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo

git lfs pull --include "exp/pretrained.pt"
git lfs pull --include "data/lang_char/H.fst"
git lfs pull --include "data/lang_char/HL.fst"
git lfs pull --include "data/lang_char/HLG.fst"

popd

log "Display test files"
tree $repo/
ls -lh $repo/test_wavs/*.wav

log "CTC decoding"

log "Exporting model with torchscript"

pushd $repo/exp
ln -s pretrained.pt epoch-99.pt
popd

./conformer_ctc/export.py \
  --epoch 99 \
  --avg 1 \
  --exp-dir $repo/exp \
  --tokens $repo/data/lang_char/tokens.txt \
  --jit 1

ls -lh $repo/exp

ls -lh $repo/data/lang_char

log "Decoding with H on CPU with OpenFst"

./conformer_ctc/jit_pretrained_decode_with_H.py \
  --nn-model $repo/exp/cpu_jit.pt \
  --H $repo/data/lang_char/H.fst \
  --tokens $repo/data/lang_char/tokens.txt \
  $repo/test_wavs/0.wav \
  $repo/test_wavs/1.wav \
  $repo/test_wavs/2.wav

log "Decoding with HL on CPU with OpenFst"

./conformer_ctc/jit_pretrained_decode_with_HL.py \
  --nn-model $repo/exp/cpu_jit.pt \
  --HL $repo/data/lang_char/HL.fst \
  --words $repo/data/lang_char/words.txt \
  $repo/test_wavs/0.wav \
  $repo/test_wavs/1.wav \
  $repo/test_wavs/2.wav

log "Decoding with HLG on CPU with OpenFst"

./conformer_ctc/jit_pretrained_decode_with_HLG.py \
  --nn-model $repo/exp/cpu_jit.pt \
  --HLG $repo/data/lang_char/HLG.fst \
  --words $repo/data/lang_char/words.txt \
  $repo/test_wavs/0.wav \
  $repo/test_wavs/1.wav \
  $repo/test_wavs/2.wav

rm -rf $repo
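For orientation, the H/HL/HLG decoding entry points exercised by this script all follow the same OpenFst pattern, implemented by the jit_pretrained_decode_with_* files added later in this commit. A minimal sketch of that pattern, using the kaldi_decoder API those files rely on (the function name and fst path here are illustrative, not part of the commit):

```python
# Sketch of the kaldi_decoder-based CTC decoding used by the new
# jit_pretrained_decode_with_{H,HL,HLG}.py scripts. `nnet_output` is a
# 2-D float32 tensor of log-probabilities (num_frames, vocab_size).
import kaldifst
from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions

def fst_ctc_decode(nnet_output, fst_path: str):
    graph = kaldifst.StdVectorFst.read(fst_path)  # H.fst, HL.fst, or HLG.fst
    decodable = DecodableCtc(nnet_output.cpu())
    decoder = FasterDecoder(graph, FasterDecoderOptions(max_active=3000))
    decoder.decode(decodable)
    ok, best_path = decoder.get_best_path()
    # Output labels on the best path are token IDs (H) or word IDs (HL/HLG)
    ok, isyms, osyms, weight = kaldifst.get_linear_symbol_sequence(best_path)
    return osyms
```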
44  .github/scripts/run-swbd-conformer-ctc-2023-08-26.sh  (vendored, Executable file)
@@ -0,0 +1,44 @@
#!/usr/bin/env bash

set -e

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

cd egs/swbd/ASR

repo_url=https://huggingface.co/zrjin/icefall-asr-swbd-conformer-ctc-2023-8-26

log "Downloading pre-trained model from $repo_url"
git lfs install
git clone $repo_url
repo=$(basename $repo_url)


log "Display test files"
tree $repo/
ls -lh $repo/test_wavs/*.wav

pushd $repo/exp
ln -s epoch-98.pt epoch-99.pt
popd

ls -lh $repo/exp/*.pt

for method in ctc-decoding 1best; do
  log "$method"

  ./conformer_ctc/pretrained.py \
    --method $method \
    --checkpoint $repo/exp/epoch-99.pt \
    --tokens $repo/data/lang_bpe_500/tokens.txt \
    --words-file $repo/data/lang_bpe_500/words.txt \
    --HLG $repo/data/lang_bpe_500/HLG.pt \
    --G $repo/data/lm/G_4_gram.pt \
    $repo/test_wavs/1089-134686-0001.wav \
    $repo/test_wavs/1221-135766-0001.wav \
    $repo/test_wavs/1221-135766-0002.wav
done
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-name: run-pre-trained-conformer-ctc
+name: run-pre-trained-ctc

 on:
   push:

@@ -23,13 +23,20 @@ on:
   pull_request:
     types: [labeled]

+  workflow_dispatch:
+    inputs:
+      test-run:
+        description: 'Test (y/n)?'
+        required: true
+        default: 'y'
+
 concurrency:
-  group: run_pre_trained_conformer_ctc-${{ github.ref }}
+  group: run_pre_trained_ctc-${{ github.ref }}
   cancel-in-progress: true

 jobs:
-  run_pre_trained_conformer_ctc:
-    if: github.event.label.name == 'ready' || github.event_name == 'push'
+  run_pre_trained_ctc:
+    if: github.event.label.name == 'ready' || github.event_name == 'push' || github.event.inputs.test-run == 'y' || github.event.label.name == 'ctc'
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:

@@ -77,4 +84,4 @@ jobs:
         export PYTHONPATH=$PWD:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
-        .github/scripts/run-pre-trained-conformer-ctc.sh
+        .github/scripts/run-pre-trained-ctc.sh
84  .github/workflows/run-swbd-conformer-ctc.yml  (vendored, Normal file)
@@ -0,0 +1,84 @@
# Copyright      2023  Xiaomi Corp.  (author: Zengrui Jin)

# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

name: run-swbd-conformer_ctc

on:
  push:
    branches:
      - master
  pull_request:
    types: [labeled]

concurrency:
  group: run-swbd-conformer_ctc-${{ github.ref }}
  cancel-in-progress: true

jobs:
  run-swbd-conformer_ctc:
    if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'swbd'
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [ubuntu-latest]
        python-version: [3.8]

      fail-fast: false

    steps:
      - uses: actions/checkout@v2
        with:
          fetch-depth: 0

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
          cache: 'pip'
          cache-dependency-path: '**/requirements-ci.txt'

      - name: Install Python dependencies
        run: |
          grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
          pip uninstall -y protobuf
          pip install --no-binary protobuf protobuf==3.20.*

      - name: Cache kaldifeat
        id: my-cache
        uses: actions/cache@v2
        with:
          path: |
            ~/tmp/kaldifeat
          key: cache-tmp-${{ matrix.python-version }}-2023-05-22

      - name: Install kaldifeat
        if: steps.my-cache.outputs.cache-hit != 'true'
        shell: bash
        run: |
          .github/scripts/install-kaldifeat.sh

      - name: Inference with pre-trained model
        shell: bash
        env:
          GITHUB_EVENT_NAME: ${{ github.event_name }}
          GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
        run: |
          sudo apt-get -qq install git-lfs tree
          export PYTHONPATH=$PWD:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH

          .github/scripts/run-swbd-conformer-ctc-2023-08-26.sh
39  .github/workflows/run-yesno-recipe.yml  (vendored)
@@ -60,7 +60,7 @@ jobs:

      - name: Install Python dependencies
        run: |
-         grep -v '^#' ./requirements-ci.txt | grep -v kaldifst | xargs -n 1 -L 1 pip install
+         grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
          pip uninstall -y protobuf
          pip install --no-binary protobuf protobuf==3.20.*

@@ -140,9 +140,46 @@ jobs:
            download/waves_yesno/0_0_0_1_0_0_0_1.wav \
            download/waves_yesno/0_0_1_0_0_0_1_0.wav

+     - name: Test decoding with H
+       shell: bash
+       working-directory: ${{github.workspace}}
+       run: |
+         export PYTHONPATH=$PWD:$PYTHONPATH
+         echo $PYTHONPATH
+
+         cd egs/yesno/ASR
+         python3 ./tdnn/export.py --epoch 14 --avg 2 --jit 1
+
+         python3 ./tdnn/jit_pretrained_decode_with_H.py \
+           --nn-model ./tdnn/exp/cpu_jit.pt \
+           --H ./data/lang_phone/H.fst \
+           --tokens ./data/lang_phone/tokens.txt \
+           ./download/waves_yesno/0_0_0_1_0_0_0_1.wav \
+           ./download/waves_yesno/0_0_1_0_0_0_1_0.wav \
+           ./download/waves_yesno/0_0_1_0_0_1_1_1.wav
+
+     - name: Test decoding with HL
+       shell: bash
+       working-directory: ${{github.workspace}}
+       run: |
+         export PYTHONPATH=$PWD:$PYTHONPATH
+         echo $PYTHONPATH
+
+         cd egs/yesno/ASR
+         python3 ./tdnn/export.py --epoch 14 --avg 2 --jit 1
+
+         python3 ./tdnn/jit_pretrained_decode_with_HL.py \
+           --nn-model ./tdnn/exp/cpu_jit.pt \
+           --HL ./data/lang_phone/HL.fst \
+           --words ./data/lang_phone/words.txt \
+           ./download/waves_yesno/0_0_0_1_0_0_0_1.wav \
+           ./download/waves_yesno/0_0_1_0_0_0_1_0.wav \
+           ./download/waves_yesno/0_0_1_0_0_1_1_1.wav
+
      - name: Show generated files
        shell: bash
        working-directory: ${{github.workspace}}
        run: |
          cd egs/yesno/ASR
          ls -lh tdnn/exp
          ls -lh data/lang_phone
2  .gitignore  (vendored)
@@ -34,3 +34,5 @@ node_modules
 *.param
 *.bin
 .DS_Store
+*.fst
+*.arpa
12  README.md
@@ -29,6 +29,7 @@ We provide the following recipes:
 - [yesno][yesno]
 - [LibriSpeech][librispeech]
 - [GigaSpeech][gigaspeech]
+- [AMI][ami]
 - [Aishell][aishell]
 - [Aishell2][aishell2]
 - [Aishell4][aishell4]

@@ -37,6 +38,7 @@ We provide the following recipes:
 - [Aidatatang_200zh][aidatatang_200zh]
 - [WenetSpeech][wenetspeech]
 - [Alimeeting][alimeeting]
+- [Switchboard][swbd]
 - [TAL_CSASR][tal_csasr]

 ### yesno

@@ -118,9 +120,9 @@ We provide a Colab notebook to run a pre-trained transducer conformer + stateles

 | Encoder | Params | test-clean | test-other |
 |-----------------|--------|------------|------------|
-| zipformer | 65.5M | 2.21 | 4.91 |
-| zipformer-small | 23.2M | 2.46 | 5.83 |
-| zipformer-large | 148.4M | 2.11 | 4.77 |
+| zipformer | 65.5M | 2.21 | 4.79 |
+| zipformer-small | 23.2M | 2.42 | 5.73 |
+| zipformer-large | 148.4M | 2.06 | 4.63 |

 Note: No auxiliary losses are used in the training and no LMs are used
 in the decoding.

@@ -338,7 +340,7 @@ We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder

 #### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss

-The best results for Chinese CER(%) and English WER(%) respectivly (zh: Chinese, en: English):
+The best results for Chinese CER(%) and English WER(%) respectively (zh: Chinese, en: English):
 |decoding-method | dev | dev_zh | dev_en | test | test_zh | test_en |
 |--|--|--|--|--|--|--|
 |greedy_search| 7.30 | 6.48 | 19.19 |7.39| 6.66 | 19.13|

@@ -393,4 +395,6 @@ Please see: [
 .. list-table::
    :widths: 25 50
    :header-rows: 1

+
+
@@ -1,3 +1,5 @@
+.. _icefall_export_to_ncnn:
+
 Export to ncnn
 ==============

@@ -47,7 +47,7 @@ The data preparation contains several stages, you can use the following two
 options:

 - ``--stage``
-- ``--stop-stage``
+- ``--stop_stage``

 to control which stage(s) should be run. By default, all stages are executed.

@@ -56,8 +56,8 @@ For example,
 .. code-block:: bash

   $ cd egs/librispeech/ASR
-  $ ./prepare.sh --stage 0 --stop-stage 0 # run only stage 0
-  $ ./prepare.sh --stage 2 --stop-stage 5 # run from stage 2 to stage 5
+  $ ./prepare.sh --stage 0 --stop_stage 0 # run only stage 0
+  $ ./prepare.sh --stage 2 --stop_stage 5 # run from stage 2 to stage 5

 .. HINT::

@@ -108,15 +108,15 @@ As usual, you can control the stages you want to run by specifying the following
 two options:

 - ``--stage``
-- ``--stop-stage``
+- ``--stop_stage``

 For example,

 .. code-block:: bash

   $ cd egs/librispeech/ASR
-  $ ./distillation_with_hubert.sh --stage 0 --stop-stage 0 # run only stage 0
-  $ ./distillation_with_hubert.sh --stage 2 --stop-stage 4 # run from stage 2 to stage 5
+  $ ./distillation_with_hubert.sh --stage 0 --stop_stage 0 # run only stage 0
+  $ ./distillation_with_hubert.sh --stage 2 --stop_stage 4 # run from stage 2 to stage 5

 Here are a few options in `./distillation_with_hubert.sh <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/distillation_with_hubert.sh>`_
 you need to know before you proceed.

@@ -134,7 +134,7 @@ and prepares MVQ-augmented training manifests.

 .. code-block:: bash

-  $ ./distillation_with_hubert.sh --stage 2 --stop-stage 2 # run only stage 2
+  $ ./distillation_with_hubert.sh --stage 2 --stop_stage 2 # run only stage 2

 Please see the
 following screenshot for the output of an example execution.

@@ -172,7 +172,7 @@ To perform training, please run stage 3 by executing the following command.

 .. code-block:: bash

-  $ ./prepare.sh --stage 3 --stop-stage 3 # run MVQ training
+  $ ./prepare.sh --stage 3 --stop_stage 3 # run MVQ training

 Here is the code snippet for training:

7  docs/source/recipes/RNN-LM/index.rst  (Normal file)
@@ -0,0 +1,7 @@
RNN-LM
======

.. toctree::
   :maxdepth: 2

   librispeech/lm-training
104  docs/source/recipes/RNN-LM/librispeech/lm-training.rst  (Normal file)
@@ -0,0 +1,104 @@
.. _train_nnlm:

Train an RNN language model
======================================

If you have enough text data, you can train a neural network language model (NNLM) to improve
the WER of your E2E ASR system. This tutorial shows you how to train an RNNLM from
scratch.

.. HINT::

  For how to use an NNLM during decoding, please refer to the following tutorials:
  :ref:`shallow_fusion`, :ref:`LODR`, :ref:`rescoring`

.. note::

  This tutorial is based on the LibriSpeech recipe. Please check it out for the necessary
  python scripts for this tutorial. We use the LibriSpeech LM-corpus as the LM training set
  for illustration purposes. You can also collect your own data. The data format is quite simple:
  each line should contain a complete sentence, and words should be separated by spaces.

First, let's download the training data for the RNNLM. This can be done via the
following command:

.. code-block:: bash

  $ wget https://www.openslr.org/resources/11/librispeech-lm-norm.txt.gz
  $ gzip -d librispeech-lm-norm.txt.gz

As we are training a BPE-level RNNLM, we need to tokenize the training text, which requires a
BPE tokenizer. This can be achieved by executing the following command:

.. code-block:: bash

  $ # if you don't have the BPE model
  $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
  $ cd icefall-asr-librispeech-zipformer-2023-05-15/data/lang_bpe_500
  $ git lfs pull --include bpe.model
  $ cd ../../..

  $ ./local/prepare_lm_training_data.py \
      --bpe-model icefall-asr-librispeech-zipformer-2023-05-15/data/lang_bpe_500/bpe.model \
      --lm-data librispeech-lm-norm.txt \
      --lm-archive data/lang_bpe_500/lm_data.pt

Now you should have a file named ``lm_data.pt`` stored under the directory ``data/lang_bpe_500``.
This is the packed training data for the RNNLM. We then sort the training data according to its
sentence length.

.. code-block:: bash

  $ # This could take a while (~ 20 minutes), feel free to grab a cup of coffee :)
  $ ./local/sort_lm_training_data.py \
      --in-lm-data data/lang_bpe_500/lm_data.pt \
      --out-lm-data data/lang_bpe_500/sorted_lm_data.pt \
      --out-statistics data/lang_bpe_500/lm_data_stats.txt


The aforementioned steps can be repeated to create a validation set for your RNNLM. Say
you have a validation set in ``valid.txt``; then you can just set ``--lm-data valid.txt``
and ``--lm-archive data/lang_bpe_500/lm-data-valid.pt`` when calling ``./local/prepare_lm_training_data.py``.

After completing the previous steps, the training and validation sets for the RNNLM are ready.
The next step is to train the RNNLM. The training command is as follows:

.. code-block:: bash

  $ # assume you are in the icefall root directory
  $ cd rnn_lm
  $ ln -s ../../egs/librispeech/ASR/data .
  $ cd ..
  $ ./rnn_lm/train.py \
      --world-size 4 \
      --exp-dir ./rnn_lm/exp \
      --start-epoch 0 \
      --num-epochs 10 \
      --use-fp16 0 \
      --tie-weights 1 \
      --embedding-dim 2048 \
      --hidden-dim 2048 \
      --num-layers 3 \
      --batch-size 300 \
      --lm-data rnn_lm/data/lang_bpe_500/sorted_lm_data.pt \
      --lm-data-valid rnn_lm/data/lang_bpe_500/sorted_lm_data.pt


.. note::

  You can adjust the RNNLM hyperparameters to control the size of the RNNLM,
  such as the embedding dimension and hidden state dimension. For more details, please
  run ``./rnn_lm/train.py --help``.

.. note::

  Training the RNNLM can take a long time (usually a couple of days).
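As an aside on the packing step above: what ``./local/prepare_lm_training_data.py`` stores per line is the BPE tokenization of that sentence. A minimal sketch with ``sentencepiece`` (the sample sentence and printed pieces are illustrative; the actual segmentation depends on the model):

```python
# Sketch: how one line of librispeech-lm-norm.txt becomes BPE token IDs.
# The bpe.model path follows the tutorial above; the sentence is made up.
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("icefall-asr-librispeech-zipformer-2023-05-15/data/lang_bpe_500/bpe.model")

line = "HELLO WORLD"
pieces = sp.encode(line, out_type=str)  # e.g. ['▁HE', 'LLO', '▁WORLD']
ids = sp.encode(line, out_type=int)     # integer IDs, as packed into lm_data.pt
print(pieces, ids)
```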
@@ -15,3 +15,4 @@ We may add recipes for other tasks as well in the future.

    Non-streaming-ASR/index
    Streaming-ASR/index
+   RNN-LM/index
@@ -635,7 +635,6 @@ def train_one_epoch(
     tot_loss = MetricsTracker()

     for batch_idx, batch in enumerate(train_dl):
-
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])


@@ -800,7 +799,7 @@ def run(rank, world_size, args):

     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)

21  egs/aishell/ASR/conformer_ctc/export.py  (Normal file → Executable file)
@@ -23,12 +23,12 @@ import argparse
 import logging
 from pathlib import Path

+import k2
 import torch
 from conformer import Conformer

 from icefall.checkpoint import average_checkpoints, load_checkpoint
-from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, str2bool
+from icefall.utils import AttributeDict, num_tokens, str2bool


 def get_parser():

@@ -63,11 +63,10 @@ def get_parser():
     )

     parser.add_argument(
-        "--lang-dir",
+        "--tokens",
         type=str,
-        default="data/lang_char",
-        help="""It contains language related input files such as "lexicon.txt"
-        """,
+        required=True,
+        help="Path to the tokens.txt.",
     )

     parser.add_argument(

@@ -98,16 +97,16 @@ def get_params() -> AttributeDict:
 def main():
     args = get_parser().parse_args()
     args.exp_dir = Path(args.exp_dir)
-    args.lang_dir = Path(args.lang_dir)

     params = get_params()
     params.update(vars(args))

-    logging.info(params)
+    # Load tokens.txt here
+    token_table = k2.SymbolTable.from_file(params.tokens)

-    lexicon = Lexicon(params.lang_dir)
-    max_token_id = max(lexicon.tokens)
-    num_classes = max_token_id + 1  # +1 for the blank
+    num_classes = num_tokens(token_table) + 1  # +1 for the blank
+
+    logging.info(params)

     device = torch.device("cpu")
     if torch.cuda.is_available():
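The switch from ``--lang-dir`` to ``--tokens`` means the exporter now needs only ``tokens.txt``, one ``<symbol> <id>`` pair per line. A rough, hand-rolled equivalent of the new ``num_tokens(token_table) + 1`` computation (it ignores the disambiguation-symbol handling that ``icefall.utils.num_tokens`` performs, and the sample file contents are illustrative):

```python
# Sketch: deriving the model's output dimension from tokens.txt.
# tokens.txt looks like:
#   <blk> 0
#   <sos/eos> 1
#   ▁THE 2
#   ...
import k2

token_table = k2.SymbolTable.from_file("data/lang_char/tokens.txt")
# IDs run from 0 (the blank) up to max, so max + 1 outputs are needed.
num_classes = max(token_table.ids) + 1  # ~ num_tokens(token_table) + 1 for dense, disambig-free tables
print(num_classes)
```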
1  egs/aishell/ASR/conformer_ctc/jit_pretrained_decode_with_H.py  (Symbolic link)
@@ -0,0 +1 @@
../../../librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py
1  egs/aishell/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py  (Symbolic link)
@@ -0,0 +1 @@
../../../librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py
1  egs/aishell/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py  (Symbolic link)
@@ -0,0 +1 @@
../../../librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py
0  egs/aishell/ASR/conformer_ctc/test_transformer.py  (Normal file → Executable file)
1  egs/aishell/ASR/local/prepare_lang_fst.py  (Symbolic link)
@@ -0,0 +1 @@
../../../librispeech/ASR/local/prepare_lang_fst.py
@@ -15,7 +15,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-
 # You can install sentencepiece via:
 #
 #     pip install sentencepiece

@@ -26,12 +25,12 @@
 # Please install a version >=0.1.96

 import argparse
+import re
 import shutil
 import tempfile
 from pathlib import Path

 import sentencepiece as spm

 from icefall import byte_encode, tokenize_by_CJK_char
-

@@ -74,6 +73,11 @@ def main():
     model_type = "unigram"

     model_prefix = f"{lang_dir}/{model_type}_{vocab_size}"
+    model_file = Path(model_prefix + ".model")
+    if model_file.is_file():
+        print(f"{model_file} exists - skipping")
+        return
+
     character_coverage = 1.0
     input_sentence_size = 100000000

@@ -88,8 +92,6 @@
     _convert_to_bchar(args.transcript, train_text)

-    model_file = Path(model_prefix + ".model")
-    if not model_file.is_file():
-        spm.SentencePieceTrainer.train(
+    spm.SentencePieceTrainer.train(
         input=train_text,
         vocab_size=vocab_size,

@@ -102,9 +104,6 @@
         bos_id=-1,
         eos_id=-1,
     )
-    else:
-        print(f"{model_file} exists - skipping")
-        return

     shutil.copyfile(model_file, f"{lang_dir}/bbpe.model")

@@ -143,6 +143,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
   ./local/prepare_lang.py --lang-dir $lang_phone_dir
 fi

+
 # Train a bigram P for MMI training
 if [ ! -f $lang_phone_dir/transcript_words.txt ]; then
   log "Generate data to train phone based bigram P"

@@ -203,6 +204,10 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
   if [ ! -f $lang_char_dir/L_disambig.pt ]; then
     ./local/prepare_char.py --lang-dir $lang_char_dir
   fi
+
+  if [ ! -f $lang_char_dir/HLG.fst ]; then
+    ./local/prepare_lang_fst.py --lang-dir $lang_phone_dir --ngram-G ./data/lm/G_3_gram.fst.txt
+  fi
 fi

 if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
@@ -872,7 +872,7 @@ def run(rank, world_size, args):

     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)

@@ -1045,7 +1045,7 @@ def run(rank, world_size, args):

     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)

@@ -322,6 +322,7 @@ def export_decoder_model_onnx(
     vocab_size = decoder_model.decoder.vocab_size

     y = torch.zeros(10, context_size, dtype=torch.int64)
+    decoder_model = torch.jit.script(decoder_model)
     torch.onnx.export(
         decoder_model,
         y,
@@ -151,12 +151,14 @@ class OnnxModel:
         self.encoder = ort.InferenceSession(
             encoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

     def init_decoder(self, decoder_model_filename: str):
         self.decoder = ort.InferenceSession(
             decoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         decoder_meta = self.decoder.get_modelmeta().custom_metadata_map

@@ -170,6 +172,7 @@ class OnnxModel:
         self.joiner = ort.InferenceSession(
             joiner_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
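A note on the two changes above: newer onnxruntime releases expect an explicit ``providers`` list when constructing an ``InferenceSession`` (GPU-enabled builds raise an error without one), which is presumably why ``CPUExecutionProvider`` is now passed everywhere; and scripting the decoder with ``torch.jit.script`` before ``torch.onnx.export`` preserves data-dependent control flow that tracing would otherwise bake in. A minimal sketch of the session setup (the model path is a placeholder):

```python
# Sketch: creating an onnxruntime session with an explicit provider list,
# matching the hunks above. "encoder.onnx" is a placeholder path.
import onnxruntime as ort

session_opts = ort.SessionOptions()
session_opts.intra_op_num_threads = 1

encoder = ort.InferenceSession(
    "encoder.onnx",
    sess_options=session_opts,
    providers=["CPUExecutionProvider"],
)
print([i.name for i in encoder.get_inputs()])
```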
@@ -1028,7 +1028,7 @@ def run(rank, world_size, args):

     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)

@@ -1031,7 +1031,7 @@ def run(rank, world_size, args):

     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)

@@ -1019,7 +1019,7 @@ def run(rank, world_size, args):

     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)

@@ -198,7 +198,7 @@ class AishellAsrDataModule:
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+                CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
@@ -730,7 +730,6 @@ def train_one_epoch(
     tot_loss = MetricsTracker()

     for batch_idx, batch in enumerate(train_dl):
-
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])


@@ -919,7 +918,7 @@ def run(rank, world_size, args):

     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)

@@ -908,7 +908,7 @@ def run(rank, world_size, args):

     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)

@@ -635,7 +635,6 @@ def train_one_epoch(
     tot_loss = MetricsTracker()

     for batch_idx, batch in enumerate(train_dl):
-
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])


@@ -800,7 +799,7 @@ def run(rank, world_size, args):

     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)

@@ -999,7 +999,7 @@ def run(rank, world_size, args):

     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)

@@ -988,7 +988,7 @@ def run(rank, world_size, args):

     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)

@@ -330,6 +330,7 @@ def export_decoder_model_onnx(
     vocab_size = decoder_model.decoder.vocab_size

     y = torch.zeros(10, context_size, dtype=torch.int64)
+    decoder_model = torch.jit.script(decoder_model)
     torch.onnx.export(
         decoder_model,
         y,
@@ -152,12 +152,14 @@ class OnnxModel:
         self.encoder = ort.InferenceSession(
             encoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

     def init_decoder(self, decoder_model_filename: str):
         self.decoder = ort.InferenceSession(
             decoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         decoder_meta = self.decoder.get_modelmeta().custom_metadata_map

@@ -171,6 +173,7 @@ class OnnxModel:
         self.joiner = ort.InferenceSession(
             joiner_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
@@ -1019,7 +1019,7 @@ def run(rank, world_size, args):

     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)

@@ -1074,7 +1074,7 @@ def run(rank, world_size, args):

     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)

@@ -1075,7 +1075,7 @@ def run(rank, world_size, args):

     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)

@@ -75,7 +75,7 @@ See <https://github.com/k2-fsa/icefall/pull/1058> for more details.
 ##### normal-scaled model, number of model parameters: 65549011, i.e., 65.55 M

 The tensorboard log can be found at
-<https://tensorboard.dev/experiment/cBaoIabCQxSDsyZM7FzqZA/>
+<https://tensorboard.dev/experiment/R2DT9Ju4QiadC4e2ioKh5A/>

 You can find a pretrained model, training logs, decoding logs, and decoding results at:
 <https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15>

@@ -90,18 +90,20 @@ You can use <https://github.com/k2-fsa/sherpa> to deploy it.
-| greedy_search | 2.23 | 4.96 | --epoch 40 --avg 16 |
-| modified_beam_search | 2.21 | 4.91 | --epoch 40 --avg 16 |
-| fast_beam_search | 2.24 | 4.93 | --epoch 40 --avg 16 |
+| greedy_search | 2.22 | 4.87 | --epoch 50 --avg 25 |
+| modified_beam_search | 2.21 | 4.79 | --epoch 50 --avg 25 |
+| fast_beam_search | 2.21 | 4.82 | --epoch 50 --avg 25 |
 | modified_beam_search_shallow_fusion | 2.01 | 4.37 | --epoch 40 --avg 16 --beam-size 12 --lm-scale 0.3 |
 | modified_beam_search_LODR | 1.94 | 4.17 | --epoch 40 --avg 16 --beam-size 12 --lm-scale 0.52 --LODR-scale -0.26 |
 | modified_beam_search_rescore | 2.04 | 4.39 | --epoch 40 --avg 16 --beam-size 12 |
 | modified_beam_search_rescore_LODR | 2.01 | 4.33 | --epoch 40 --avg 16 --beam-size 12 |


 The training command is:
 ```bash
 export CUDA_VISIBLE_DEVICES="0,1,2,3"
 ./zipformer/train.py \
   --world-size 4 \
-  --num-epochs 40 \
+  --num-epochs 50 \
   --start-epoch 1 \
   --use-fp16 1 \
   --exp-dir zipformer/exp \

@@ -115,8 +117,8 @@ The decoding command is:
 export CUDA_VISIBLE_DEVICES="0"
 for m in greedy_search modified_beam_search fast_beam_search; do
   ./zipformer/decode.py \
-    --epoch 30 \
-    --avg 9 \
+    --epoch 50 \
+    --avg 25 \
     --use-averaged-model 1 \
     --exp-dir ./zipformer/exp \
     --max-duration 600 \

@@ -129,7 +131,7 @@ To decode with external language models, please refer to the documentation [here
 ##### small-scaled model, number of model parameters: 23285615, i.e., 23.3 M

 The tensorboard log can be found at
-<https://tensorboard.dev/experiment/53P4tL22TpO0UdiL0kPaLg/>
+<https://tensorboard.dev/experiment/M9C8cYPWSN2MVBYaBIX3EQ/>

 You can find a pretrained model, training logs, decoding logs, and decoding results at:
 <https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-small-2023-05-16>

@@ -144,13 +146,16 @@ You can use <https://github.com/k2-fsa/sherpa> to deploy it.
-| greedy_search | 2.49 | 5.91 | --epoch 40 --avg 13 |
-| modified_beam_search | 2.46 | 5.83 | --epoch 40 --avg 13 |
-| fast_beam_search | 2.46 | 5.87 | --epoch 40 --avg 13 |
+| greedy_search | 2.46 | 5.86 | --epoch 50 --avg 23 |
+| modified_beam_search | 2.42 | 5.73 | --epoch 50 --avg 23 |
+| fast_beam_search | 2.46 | 5.78 | --epoch 50 --avg 23 |

 The training command is:
 ```bash
 export CUDA_VISIBLE_DEVICES="0,1"
 ./zipformer/train.py \
   --world-size 2 \
-  --num-epochs 40 \
+  --num-epochs 50 \
   --start-epoch 1 \
   --use-fp16 1 \
   --exp-dir zipformer/exp-small \

@@ -169,8 +174,8 @@ The decoding command is:
 export CUDA_VISIBLE_DEVICES="0"
 for m in greedy_search modified_beam_search fast_beam_search; do
   ./zipformer/decode.py \
-    --epoch 40 \
-    --avg 13 \
+    --epoch 50 \
+    --avg 23 \
     --exp-dir zipformer/exp-small \
     --max-duration 600 \
     --causal 0 \

@@ -185,7 +190,7 @@ done
 ##### large-scaled model, number of model parameters: 148439574, i.e., 148.4 M

 The tensorboard log can be found at
-<https://tensorboard.dev/experiment/HJ74wWYpQAGSzETkmQnrmQ/>
+<https://tensorboard.dev/experiment/C5ZPE5u1So2ZwhYLKW0FVg/>

 You can find a pretrained model, training logs, decoding logs, and decoding results at:
 <https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-large-2023-05-16>

@@ -200,13 +205,16 @@ You can use <https://github.com/k2-fsa/sherpa> to deploy it.
-| greedy_search | 2.12 | 4.8 | --epoch 40 --avg 13 |
-| modified_beam_search | 2.11 | 4.7 | --epoch 40 --avg 13 |
-| fast_beam_search | 2.13 | 4.78 | --epoch 40 --avg 13 |
+| greedy_search | 2.08 | 4.69 | --epoch 50 --avg 30 |
+| modified_beam_search | 2.06 | 4.63 | --epoch 50 --avg 30 |
+| fast_beam_search | 2.09 | 4.68 | --epoch 50 --avg 30 |

 The training command is:
 ```bash
 export CUDA_VISIBLE_DEVICES="0,1,2,3"
 ./zipformer/train.py \
   --world-size 4 \
-  --num-epochs 40 \
+  --num-epochs 50 \
   --start-epoch 1 \
   --use-fp16 1 \
   --exp-dir zipformer/exp-large \

@@ -224,8 +232,8 @@ The decoding command is:
 export CUDA_VISIBLE_DEVICES="0"
 for m in greedy_search modified_beam_search fast_beam_search; do
   ./zipformer/decode.py \
-    --epoch 40 \
-    --avg 16 \
+    --epoch 50 \
+    --avg 30 \
     --exp-dir zipformer/exp-large \
     --max-duration 600 \
     --causal 0 \
247  egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py  (Executable file)
@@ -0,0 +1,247 @@
#!/usr/bin/env python3
# Copyright      2023  Xiaomi Corp.  (authors: Fangjun Kuang)

"""
This file shows how to use a torchscript model for decoding with H
on CPU using OpenFST and decoders from kaldi.

Usage:

(1) LibriSpeech conformer_ctc

  ./conformer_ctc/jit_pretrained_decode_with_H.py \
    --nn-model ./conformer_ctc/exp/cpu_jit.pt \
    --H ./data/lang_bpe_500/H.fst \
    --tokens ./data/lang_bpe_500/tokens.txt \
    ./download/LibriSpeech/test-clean/1089/134686/1089-134686-0002.flac \
    ./download/LibriSpeech/test-clean/1221/135766/1221-135766-0001.flac


(2) AIShell conformer_ctc

  ./conformer_ctc/jit_pretrained_decode_with_H.py \
    --nn-model ./conformer_ctc/exp/cpu_jit.pt \
    --H ./data/lang_char/H.fst \
    --tokens ./data/lang_char/tokens.txt \
    ./BAC009S0764W0121.wav \
    ./BAC009S0764W0122.wav \
    ./BAC009S0764W0123.wav

Note that to generate ./conformer_ctc/exp/cpu_jit.pt,
you can use ./export.py --jit 1
"""

import argparse
import logging
import math
from typing import Dict, List

import kaldifeat
import kaldifst
import torch
import torchaudio
from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions
from torch.nn.utils.rnn import pad_sequence


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--nn-model",
        type=str,
        required=True,
        help="""Path to the torchscript model.
        You can use ./conformer_ctc/export.py --jit 1
        to obtain it
        """,
    )

    parser.add_argument(
        "--tokens",
        type=str,
        required=True,
        help="Path to tokens.txt",
    )

    parser.add_argument("--H", type=str, required=True, help="Path to H.fst")

    parser.add_argument(
        "sound_files",
        type=str,
        nargs="+",
        help="The input sound file(s) to transcribe. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported. ",
    )

    return parser


def read_tokens(tokens_txt: str) -> Dict[int, str]:
    id2token = dict()
    with open(tokens_txt, encoding="utf-8") as f:
        for line in f:
            token, idx = line.strip().split()
            id2token[int(idx)] = token

    return id2token


def read_sound_files(
    filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
    """Read a list of sound files into a list 1-D float32 torch tensors.
    Args:
      filenames:
        A list of sound filenames.
      expected_sample_rate:
        The expected sample rate of the sound files.
    Returns:
      Return a list of 1-D float32 torch tensors.
    """
    ans = []
    for f in filenames:
        wave, sample_rate = torchaudio.load(f)
        if sample_rate != expected_sample_rate:
            wave = torchaudio.functional.resample(
                wave,
                orig_freq=sample_rate,
                new_freq=expected_sample_rate,
            )

        # We use only the first channel
        ans.append(wave[0].contiguous())
    return ans


def decode(
    filename: str,
    nnet_output: torch.Tensor,
    H: kaldifst,
    id2token: Dict[int, str],
) -> List[str]:
    """
    Args:
      filename:
        Path to the filename for decoding. Used for debugging.
      nnet_output:
        A 2-D float32 tensor of shape (num_frames, vocab_size). It
        contains output from log_softmax.
      H:
        The H graph.
      id2token:
        A map mapping token ID to token string.
    Returns:
      Return a list of decoded tokens.
    """
    logging.info(f"{filename}, {nnet_output.shape}")
    decodable = DecodableCtc(nnet_output.cpu())

    decoder_opts = FasterDecoderOptions(max_active=3000)
    decoder = FasterDecoder(H, decoder_opts)
    decoder.decode(decodable)

    if not decoder.reached_final():
        logging.info(f"failed to decode {filename}")
        return [""]

    ok, best_path = decoder.get_best_path()

    (
        ok,
        isymbols_out,
        osymbols_out,
        total_weight,
    ) = kaldifst.get_linear_symbol_sequence(best_path)
    if not ok:
        logging.info(f"failed to get linear symbol sequence for {filename}")
        return [""]

    # tokens are incremented during graph construction
    # so they need to be decremented
    hyps = [id2token[i - 1] for i in osymbols_out]
    # hyps = "".join(hyps).split("▁")
    hyps = "".join(hyps).split("\u2581")  # unicode codepoint of ▁

    return hyps


@torch.no_grad()
def main():
    parser = get_parser()
    args = parser.parse_args()

    device = torch.device("cpu")

    logging.info(f"device: {device}")

    logging.info("Loading torchscript model")
    model = torch.jit.load(args.nn_model)
    model.eval()
    model.to(device)

    logging.info(f"Loading H from {args.H}")
    H = kaldifst.StdVectorFst.read(args.H)

    sample_rate = 16000

    logging.info("Constructing Fbank computer")
    opts = kaldifeat.FbankOptions()
    opts.device = device
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = sample_rate
    opts.mel_opts.num_bins = 80

    fbank = kaldifeat.Fbank(opts)

    logging.info(f"Reading sound files: {args.sound_files}")
    waves = read_sound_files(
        filenames=args.sound_files, expected_sample_rate=sample_rate
    )
    waves = [w.to(device) for w in waves]

    logging.info("Decoding started")
    features = fbank(waves)
    feature_lengths = [f.shape[0] for f in features]
    feature_lengths = torch.tensor(feature_lengths)

    supervisions = dict()
    supervisions["sequence_idx"] = torch.arange(len(features))
    supervisions["start_frame"] = torch.zeros(len(features))
    supervisions["num_frames"] = feature_lengths

    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))

    nnet_output, _, _ = model(features, supervisions)
    feature_lengths = ((feature_lengths - 1) // 2 - 1) // 2

    id2token = read_tokens(args.tokens)

    hyps = []
    for i in range(nnet_output.shape[0]):
        hyp = decode(
            filename=args.sound_files[i],
            nnet_output=nnet_output[i, : feature_lengths[i]],
            H=H,
            id2token=id2token,
        )
        hyps.append(hyp)

    s = "\n"
    for filename, hyp in zip(args.sound_files, hyps):
        words = " ".join(hyp)
        s += f"{filename}:\n{words}\n\n"
    logging.info(s)

    logging.info("Decoding Done")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
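Two details in decode() above are easy to miss: output-side token IDs in H.fst are shifted up by one during graph construction (hence ``id2token[i - 1]``), and BPE pieces are glued back into words by splitting on the ``▁`` word-boundary marker. A self-contained illustration of the second step (the pieces are made up):

```python
# Sketch of the BPE-piece-to-word step in decode() above.
pieces = ["\u2581HE", "LLO", "\u2581WORLD"]  # \u2581 is ▁, marking a word start
joined = "".join(pieces)                     # '▁HELLO▁WORLD'
words = joined.split("\u2581")               # ['', 'HELLO', 'WORLD']
# The leading '' appears because the string starts with ▁; when main()
# joins the hypothesis with spaces it only produces a harmless leading space.
print(" ".join(words).strip())               # 'HELLO WORLD'
```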
244  egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py  (Executable file)
@ -0,0 +1,244 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
|
||||
"""
|
||||
This file shows how to use a torchscript model for decoding with HL
|
||||
on CPU using OpenFST and decoders from kaldi.
|
||||
|
||||
Usage:
|
||||
|
||||
(1) LibriSpeech conformer_ctc
|
||||
|
||||
./conformer_ctc/jit_pretrained_decode_with_HL.py \
|
||||
--nn-model ./conformer_ctc/exp/cpu_jit.pt \
|
||||
--HL ./data/lang_bpe_500/HL.fst \
|
||||
--words ./data/lang_bpe_500/words.txt \
|
||||
./download/LibriSpeech/test-clean/1089/134686/1089-134686-0002.flac \
|
||||
./download/LibriSpeech/test-clean/1221/135766/1221-135766-0001.flac
|
||||
|
||||
(2) AIShell conformer_ctc
|
||||
|
||||
./conformer_ctc/jit_pretrained_decode_with_HL.py \
|
||||
--nn-model ./conformer_ctc/exp/cpu_jit.pt \
|
||||
--HL ./data/lang_char/HL.fst \
|
||||
--words ./data/lang_char/words.txt \
|
||||
./BAC009S0764W0121.wav \
|
||||
./BAC009S0764W0122.wav \
|
  ./BAC009S0764W0123.wav

Note that to generate ./conformer_ctc/exp/cpu_jit.pt,
you can use ./export.py --jit 1
"""

import argparse
import logging
import math
from typing import Dict, List

import kaldifeat
import kaldifst
import torch
import torchaudio
from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions
from torch.nn.utils.rnn import pad_sequence


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--nn-model",
        type=str,
        required=True,
        help="""Path to the torchscript model.
        You can use ./conformer_ctc/export.py --jit 1
        to obtain it.
        """,
    )

    parser.add_argument(
        "--words",
        type=str,
        required=True,
        help="Path to words.txt",
    )

    parser.add_argument("--HL", type=str, required=True, help="Path to HL.fst")

    parser.add_argument(
        "sound_files",
        type=str,
        nargs="+",
        help="The input sound file(s) to transcribe. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported.",
    )

    return parser


def read_words(words_txt: str) -> Dict[int, str]:
    id2word = dict()
    with open(words_txt, encoding="utf-8") as f:
        for line in f:
            word, idx = line.strip().split()
            id2word[int(idx)] = word

    return id2word


def read_sound_files(
    filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
    """Read a list of sound files into a list of 1-D float32 torch tensors.
    Args:
      filenames:
        A list of sound filenames.
      expected_sample_rate:
        The expected sample rate of the sound files.
    Returns:
      Return a list of 1-D float32 torch tensors.
    """
    ans = []
    for f in filenames:
        wave, sample_rate = torchaudio.load(f)
        if sample_rate != expected_sample_rate:
            wave = torchaudio.functional.resample(
                wave,
                orig_freq=sample_rate,
                new_freq=expected_sample_rate,
            )

        # We use only the first channel
        ans.append(wave[0].contiguous())
    return ans


def decode(
    filename: str,
    nnet_output: torch.Tensor,
    HL: kaldifst.StdVectorFst,
    id2word: Dict[int, str],
) -> List[str]:
    """
    Args:
      filename:
        Path to the filename for decoding. Used for debugging.
      nnet_output:
        A 2-D float32 tensor of shape (num_frames, vocab_size). It
        contains output from log_softmax.
      HL:
        The HL graph.
      id2word:
        A dict mapping word IDs to word strings.
    Returns:
      Return a list of decoded words.
    """
    logging.info(f"{filename}, {nnet_output.shape}")
    decodable = DecodableCtc(nnet_output.cpu())

    decoder_opts = FasterDecoderOptions(max_active=3000)
    decoder = FasterDecoder(HL, decoder_opts)
    decoder.decode(decodable)

    if not decoder.reached_final():
        logging.info(f"failed to decode {filename}")
        return [""]

    ok, best_path = decoder.get_best_path()

    (
        ok,
        isymbols_out,
        osymbols_out,
        total_weight,
    ) = kaldifst.get_linear_symbol_sequence(best_path)
    if not ok:
        logging.info(f"failed to get linear symbol sequence for {filename}")
        return [""]

    # Output labels are word IDs; only input token IDs
    # are shifted by 1 during graph construction
    hyps = [id2word[i] for i in osymbols_out]

    return hyps


@torch.no_grad()
def main():
    parser = get_parser()
    args = parser.parse_args()

    device = torch.device("cpu")

    logging.info(f"device: {device}")

    logging.info("Loading torchscript model")
    model = torch.jit.load(args.nn_model)
    model.eval()
    model.to(device)

    logging.info(f"Loading HL from {args.HL}")
    HL = kaldifst.StdVectorFst.read(args.HL)

    sample_rate = 16000

    logging.info("Constructing Fbank computer")
    opts = kaldifeat.FbankOptions()
    opts.device = device
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = sample_rate
    opts.mel_opts.num_bins = 80

    fbank = kaldifeat.Fbank(opts)

    logging.info(f"Reading sound files: {args.sound_files}")
    waves = read_sound_files(
        filenames=args.sound_files, expected_sample_rate=sample_rate
    )
    waves = [w.to(device) for w in waves]

    logging.info("Decoding started")
    features = fbank(waves)
    feature_lengths = [f.shape[0] for f in features]
    feature_lengths = torch.tensor(feature_lengths)

    supervisions = dict()
    supervisions["sequence_idx"] = torch.arange(len(features))
    supervisions["start_frame"] = torch.zeros(len(features))
    supervisions["num_frames"] = feature_lengths

    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))

    nnet_output, _, _ = model(features, supervisions)
    # The conformer encoder subsamples its input by a factor of 4
    feature_lengths = ((feature_lengths - 1) // 2 - 1) // 2

    id2word = read_words(args.words)

    hyps = []
    for i in range(nnet_output.shape[0]):
        hyp = decode(
            filename=args.sound_files[i],
            nnet_output=nnet_output[i, : feature_lengths[i]],
            HL=HL,
            id2word=id2word,
        )
        hyps.append(hyp)

    s = "\n"
    for filename, hyp in zip(args.sound_files, hyps):
        words = " ".join(hyp)
        s += f"{filename}:\n{words}\n\n"
    logging.info(s)

    logging.info("Decoding Done")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
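The length update ((feature_lengths - 1) // 2 - 1) // 2 near the end of main() mirrors the conformer encoder's two stride-2 subsampling convolutions, an overall downsampling factor of roughly 4. A quick standalone sanity check of the arithmetic (the helper name is ours, not icefall's):

    # Each stride-2 convolution without padding maps n frames to (n - 1) // 2.
    def subsampled_length(n: int) -> int:
        return ((n - 1) // 2 - 1) // 2

    assert subsampled_length(100) == 24    # 100 input frames -> 24 encoder frames
    assert subsampled_length(1000) == 249  # close to 1000 / 4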
243
egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py
Executable file
@ -0,0 +1,243 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)

"""
This file shows how to use a torchscript model for decoding with HLG
on CPU using OpenFST and decoders from kaldi.

Usage:

(1) LibriSpeech conformer_ctc

./conformer_ctc/jit_pretrained_decode_with_HLG.py \
  --nn-model ./conformer_ctc/exp/cpu_jit.pt \
  --HLG ./data/lang_bpe_500/HLG.fst \
  --words ./data/lang_bpe_500/words.txt \
  ./download/LibriSpeech/test-clean/1089/134686/1089-134686-0002.flac \
  ./download/LibriSpeech/test-clean/1221/135766/1221-135766-0001.flac

(2) AIShell conformer_ctc

./conformer_ctc/jit_pretrained_decode_with_HLG.py \
  --nn-model ./conformer_ctc/exp/cpu_jit.pt \
  --HLG ./data/lang_char/HLG.fst \
  --words ./data/lang_char/words.txt \
  ./BAC009S0764W0121.wav \
  ./BAC009S0764W0122.wav \
  ./BAC009S0764W0123.wav

Note that to generate ./conformer_ctc/exp/cpu_jit.pt,
you can use ./export.py --jit 1
"""

import argparse
import logging
import math
from typing import Dict, List

import kaldifeat
import kaldifst
import torch
import torchaudio
from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions
from torch.nn.utils.rnn import pad_sequence


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--nn-model",
        type=str,
        required=True,
        help="""Path to the torchscript model.
        You can use ./conformer_ctc/export.py --jit 1
        to obtain it.
        """,
    )

    parser.add_argument(
        "--words",
        type=str,
        required=True,
        help="Path to words.txt",
    )

    parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.fst")

    parser.add_argument(
        "sound_files",
        type=str,
        nargs="+",
        help="The input sound file(s) to transcribe. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported.",
    )

    return parser


def read_words(words_txt: str) -> Dict[int, str]:
    id2word = dict()
    with open(words_txt, encoding="utf-8") as f:
        for line in f:
            word, idx = line.strip().split()
            id2word[int(idx)] = word

    return id2word


def read_sound_files(
    filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
    """Read a list of sound files into a list of 1-D float32 torch tensors.
    Args:
      filenames:
        A list of sound filenames.
      expected_sample_rate:
        The expected sample rate of the sound files.
    Returns:
      Return a list of 1-D float32 torch tensors.
    """
    ans = []
    for f in filenames:
        wave, sample_rate = torchaudio.load(f)
        if sample_rate != expected_sample_rate:
            wave = torchaudio.functional.resample(
                wave,
                orig_freq=sample_rate,
                new_freq=expected_sample_rate,
            )

        # We use only the first channel
        ans.append(wave[0].contiguous())
    return ans


def decode(
    filename: str,
    nnet_output: torch.Tensor,
    HLG: kaldifst.StdVectorFst,
    id2word: Dict[int, str],
) -> List[str]:
    """
    Args:
      filename:
        Path to the filename for decoding. Used for debugging.
      nnet_output:
        A 2-D float32 tensor of shape (num_frames, vocab_size). It
        contains output from log_softmax.
      HLG:
        The HLG graph.
      id2word:
        A dict mapping word IDs to word strings.
    Returns:
      Return a list of decoded words.
    """
    logging.info(f"{filename}, {nnet_output.shape}")
    decodable = DecodableCtc(nnet_output.cpu())

    decoder_opts = FasterDecoderOptions(max_active=3000)
    decoder = FasterDecoder(HLG, decoder_opts)
    decoder.decode(decodable)

    if not decoder.reached_final():
        logging.info(f"failed to decode {filename}")
        return [""]

    ok, best_path = decoder.get_best_path()

    (
        ok,
        isymbols_out,
        osymbols_out,
        total_weight,
    ) = kaldifst.get_linear_symbol_sequence(best_path)
    if not ok:
        logging.info(f"failed to get linear symbol sequence for {filename}")
        return [""]

    # Output labels are word IDs; only input token IDs
    # are shifted by 1 during graph construction
    hyps = [id2word[i] for i in osymbols_out]

    return hyps


@torch.no_grad()
def main():
    parser = get_parser()
    args = parser.parse_args()

    device = torch.device("cpu")

    logging.info(f"device: {device}")

    logging.info("Loading torchscript model")
    model = torch.jit.load(args.nn_model)
    model.eval()
    model.to(device)

    logging.info(f"Loading HLG from {args.HLG}")
    HLG = kaldifst.StdVectorFst.read(args.HLG)

    sample_rate = 16000

    logging.info("Constructing Fbank computer")
    opts = kaldifeat.FbankOptions()
    opts.device = device
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = sample_rate
    opts.mel_opts.num_bins = 80

    fbank = kaldifeat.Fbank(opts)

    logging.info(f"Reading sound files: {args.sound_files}")
    waves = read_sound_files(
        filenames=args.sound_files, expected_sample_rate=sample_rate
    )
    waves = [w.to(device) for w in waves]

    logging.info("Decoding started")
    features = fbank(waves)
    feature_lengths = [f.shape[0] for f in features]
    feature_lengths = torch.tensor(feature_lengths)

    supervisions = dict()
    supervisions["sequence_idx"] = torch.arange(len(features))
    supervisions["start_frame"] = torch.zeros(len(features))
    supervisions["num_frames"] = feature_lengths

    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))

    nnet_output, _, _ = model(features, supervisions)
    # The conformer encoder subsamples its input by a factor of 4
    feature_lengths = ((feature_lengths - 1) // 2 - 1) // 2

    id2word = read_words(args.words)

    hyps = []
    for i in range(nnet_output.shape[0]):
        hyp = decode(
            filename=args.sound_files[i],
            nnet_output=nnet_output[i, : feature_lengths[i]],
            HLG=HLG,
            id2word=id2word,
        )
        hyps.append(hyp)

    s = "\n"
    for filename, hyp in zip(args.sound_files, hyps):
        words = " ".join(hyp)
        s += f"{filename}:\n{words}\n\n"
    logging.info(s)

    logging.info("Decoding Done")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
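Both decoding scripts pad the feature batch with padding_value=math.log(1e-10) rather than 0: the features are log-mel energies, so a large negative constant makes padded frames look like near-silence instead of frames with unit energy. A minimal sketch of the behavior (shapes are illustrative):

    import math

    import torch
    from torch.nn.utils.rnn import pad_sequence

    a = torch.zeros(5, 80)  # 5 frames, 80 mel bins
    b = torch.zeros(3, 80)  # 3 frames

    batch = pad_sequence([a, b], batch_first=True, padding_value=math.log(1e-10))
    print(batch.shape)     # torch.Size([2, 5, 80])
    print(batch[1, 3, 0])  # tensor(-23.0259): the padded region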
@ -557,7 +557,6 @@ def train_one_epoch(
                )

            if batch_idx % params.log_interval == 0:

                if tb_writer is not None:
                    loss_info.write_summary(
                        tb_writer, "train/current_", params.batch_idx_train
@ -953,7 +953,7 @@ def run(rank, world_size, args):

    if params.print_diagnostics:
        opts = diagnostics.TensorDiagnosticOptions(
            2**22
            512
        )  # allow 4 megabytes per sub-module
        diagnostic = diagnostics.attach_diagnostics(model, opts)

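For the record, the old argument matches the comment on this hunk: 2**22 bytes is exactly 4 MiB per sub-module. What the new argument 512 means is defined by the updated TensorDiagnosticOptions signature in icefall's diagnostics module, which this diff does not show, so we do not guess at it here. The byte arithmetic:

    assert 2**22 == 4 * 1024 * 1024  # 4 MiB, as the old comment says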
@ -401,6 +401,7 @@ def export_decoder_model_onnx(
    vocab_size = decoder_model.decoder.vocab_size

    y = torch.zeros(10, context_size, dtype=torch.int64)
    decoder_model = torch.jit.script(decoder_model)
    torch.onnx.export(
        decoder_model,
        y,
@ -136,6 +136,7 @@ class OnnxModel:
        self.encoder = ort.InferenceSession(
            encoder_model_filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )
        self.init_encoder_states()

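The change repeated throughout these hunks is the explicit providers argument: recent onnxruntime releases expect the execution providers to be listed when a session is created rather than inferred. A minimal standalone sketch (model.onnx is a placeholder path):

    import onnxruntime as ort

    opts = ort.SessionOptions()
    opts.intra_op_num_threads = 1  # single-threaded CPU inference

    session = ort.InferenceSession(
        "model.onnx",  # placeholder; any exported onnx model
        sess_options=opts,
        providers=["CPUExecutionProvider"],
    )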
@ -184,6 +185,7 @@ class OnnxModel:
        self.decoder = ort.InferenceSession(
            decoder_model_filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

        decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
@ -197,6 +199,7 @@ class OnnxModel:
        self.joiner = ort.InferenceSession(
            joiner_model_filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

        joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
@ -953,7 +953,7 @@ def run(rank, world_size, args):

    if params.print_diagnostics:
        opts = diagnostics.TensorDiagnosticOptions(
            2**22
            512
        )  # allow 4 megabytes per sub-module
        diagnostic = diagnostics.attach_diagnostics(model, opts)

@ -955,7 +955,7 @@ def run(rank, world_size, args):

    if params.print_diagnostics:
        opts = diagnostics.TensorDiagnosticOptions(
            2**22
            512
        )  # allow 4 megabytes per sub-module
        diagnostic = diagnostics.attach_diagnostics(model, opts)

@ -56,6 +56,8 @@ use_extracted_codebook=True
# "hubert_xtralarge_ll60k" -> pretrained model without finetuning
teacher_model_id=hubert_xtralarge_ll60k_finetune_ls960

. shared/parse_options.sh || exit 1

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
@ -43,6 +43,7 @@ from pathlib import Path

from tqdm.auto import tqdm


# This function is copied from lhotse
def tqdm_urlretrieve_hook(t):
    """Wraps tqdm instance.
219
egs/librispeech/ASR/local/prepare_lang_fst.py
Executable file
@ -0,0 +1,219 @@
#!/usr/bin/env python3

# Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)

"""
This script takes as input lang_dir containing lexicon_disambig.txt,
tokens.txt, and words.txt and generates the following files:

- H.fst
- HL.fst
- HLG.fst

Note that saved files are in OpenFst binary format.

Usage:

./local/prepare_lang_fst.py \
  --lang-dir ./data/lang_phone \
  --has-silence 1

Or

./local/prepare_lang_fst.py \
  --lang-dir ./data/lang_bpe_500
"""

import argparse
import logging
from pathlib import Path

import kaldifst

from icefall.ctc import (
    Lexicon,
    add_disambig_self_loops,
    add_one,
    build_standard_ctc_topo,
    make_lexicon_fst_no_silence,
    make_lexicon_fst_with_silence,
)
from icefall.utils import str2bool


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lang-dir",
        type=str,
        help="""Input and output directory.
        """,
    )

    parser.add_argument(
        "--has-silence",
        type=str2bool,
        default=False,
        help="True if the lexicon has silence.",
    )

    parser.add_argument(
        "--ngram-G",
        type=str,
        help="""If not empty, it is the filename of G used to build HLG.
        For instance, --ngram-G=./data/lm/G_3_fst.txt
        """,
    )

    return parser.parse_args()


def build_HL(
    H: kaldifst.StdVectorFst,
    L: kaldifst.StdVectorFst,
    has_silence: bool,
    lexicon: Lexicon,
) -> kaldifst.StdVectorFst:
    if has_silence:
        # We also need to change the input labels of L
        add_one(L, treat_ilabel_zero_specially=True, update_olabel=False)
    else:
        add_one(L, treat_ilabel_zero_specially=False, update_olabel=False)

    # Invoke add_disambig_self_loops() so that it eats the disambig symbols
    # from L after composition
    add_disambig_self_loops(
        H,
        start=lexicon.token2id["#0"] + 1,
        end=lexicon.max_disambig_id + 1,
    )

    kaldifst.arcsort(H, sort_type="olabel")
    kaldifst.arcsort(L, sort_type="ilabel")

    HL = kaldifst.compose(H, L)
    kaldifst.determinize_star(HL)

    disambig0 = lexicon.token2id["#0"] + 1
    max_disambig = lexicon.max_disambig_id + 1
    for state in kaldifst.StateIterator(HL):
        for arc in kaldifst.ArcIterator(HL, state):
            # If treat_ilabel_zero_specially is False, we always change it.
            # Otherwise, we only change non-zero input labels.
            if disambig0 <= arc.ilabel <= max_disambig:
                arc.ilabel = 0

    # Note: We are not composing L with G, so there is no need to add
    # self-loops to L to handle #0

    return HL


def build_HLG(
    H: kaldifst.StdVectorFst,
    L: kaldifst.StdVectorFst,
    G: kaldifst.StdVectorFst,
    has_silence: bool,
    lexicon: Lexicon,
) -> kaldifst.StdVectorFst:
    if has_silence:
        # We also need to change the input labels of L
        add_one(L, treat_ilabel_zero_specially=True, update_olabel=False)
    else:
        add_one(L, treat_ilabel_zero_specially=False, update_olabel=False)

    # add-self-loops
    token_disambig0 = lexicon.token2id["#0"] + 1
    word_disambig0 = lexicon.word2id["#0"]

    kaldifst.add_self_loops(L, isyms=[token_disambig0], osyms=[word_disambig0])

    kaldifst.arcsort(L, sort_type="olabel")
    kaldifst.arcsort(G, sort_type="ilabel")
    LG = kaldifst.compose(L, G)
    kaldifst.determinize_star(LG)
    kaldifst.minimize_encoded(LG)

    kaldifst.arcsort(LG, sort_type="ilabel")

    # Invoke add_disambig_self_loops() so that it eats the disambig symbols
    # from L after composition
    add_disambig_self_loops(
        H,
        start=lexicon.token2id["#0"] + 1,
        end=lexicon.max_disambig_id + 1,
    )

    kaldifst.arcsort(H, sort_type="olabel")

    HLG = kaldifst.compose(H, LG)
    kaldifst.determinize_star(HLG)

    disambig0 = lexicon.token2id["#0"] + 1
    max_disambig = lexicon.max_disambig_id + 1
    for state in kaldifst.StateIterator(HLG):
        for arc in kaldifst.ArcIterator(HLG, state):
            # If treat_ilabel_zero_specially is False, we always change it.
            # Otherwise, we only change non-zero input labels.
            if disambig0 <= arc.ilabel <= max_disambig:
                arc.ilabel = 0
    return HLG


def copy_fst(fst):
    # Please don't use fst.copy()
    return kaldifst.StdVectorFst(fst)


def main():
    args = get_args()
    lang_dir = args.lang_dir

    lexicon = Lexicon(lang_dir)

    logging.info("Building standard CTC topology")
    max_token_id = max(lexicon.tokens)
    H = build_standard_ctc_topo(max_token_id=max_token_id)

    # We need to add one to all tokens since we want to use ID 0
    # for epsilon
    add_one(H, treat_ilabel_zero_specially=False, update_olabel=True)
    H.write(f"{lang_dir}/H.fst")

    logging.info("Building L")
    # Now for HL

    if args.has_silence:
        L = make_lexicon_fst_with_silence(lexicon, attach_symbol_table=False)
    else:
        L = make_lexicon_fst_no_silence(lexicon, attach_symbol_table=False)

    logging.info("Building HL")
    HL = build_HL(
        H=copy_fst(H),
        L=copy_fst(L),
        has_silence=args.has_silence,
        lexicon=lexicon,
    )
    HL.write(f"{lang_dir}/HL.fst")

    if not args.ngram_G:
        logging.info("Skip building HLG")
        return

    logging.info("Building HLG")
    with open(args.ngram_G) as f:
        G = kaldifst.compile(
            s=f.read(),
            acceptor=False,
        )

    HLG = build_HLG(H=H, L=L, G=G, has_silence=args.has_silence, lexicon=lexicon)
    HLG.write(f"{lang_dir}/HLG.fst")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
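The add_one calls in this script exist because OpenFst reserves label 0 for epsilon: every token ID from tokens.txt is shifted up by one inside H and L, and the disambig self-loops plus the relabel-to-zero loops strip the auxiliary symbols again after composition. A toy illustration of the shift convention, in plain Python with made-up IDs:

    # tokens.txt assigns 0-based IDs, e.g. blank gets 0.
    token2id = {"<blk>": 0, "a": 1, "b": 2}

    # Inside the FSTs every token ID is incremented so that 0 can mean epsilon.
    fst_ilabel = {tok: idx + 1 for tok, idx in token2id.items()}
    assert fst_ilabel["<blk>"] == 1  # blank no longer collides with epsilon

    # Mapping an FST input label back to a token ID undoes the shift.
    assert fst_ilabel["a"] - 1 == token2id["a"]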
@ -236,7 +236,7 @@ def greedy_search_batch(
    encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)

    offset = 0
    for (t, batch_size) in enumerate(batch_size_list):
    for t, batch_size in enumerate(batch_size_list):
        start = offset
        end = offset + batch_size
        current_encoder_out = encoder_out.data[start:end]
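In these loops batch_size_list comes from a PackedSequence: entry t is the number of utterances (sorted by length) still active at frame t, so encoder_out.data[start:end] slices exactly those rows. A small sketch of where the list originates (shapes are illustrative):

    import torch
    from torch.nn.utils.rnn import pack_padded_sequence

    x = torch.zeros(3, 5, 4)           # 3 utterances, up to 5 frames, 4 features
    lengths = torch.tensor([5, 3, 2])  # sorted in decreasing order

    packed = pack_padded_sequence(x, lengths, batch_first=True)
    print(packed.batch_sizes.tolist())  # [3, 3, 2, 1, 1]: active utterances per frame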
@ -507,7 +507,7 @@ def modified_beam_search(

    offset = 0
    finalized_B = []
    for (t, batch_size) in enumerate(batch_size_list):
    for t, batch_size in enumerate(batch_size_list):
        start = offset
        end = offset + batch_size
        current_encoder_out = encoder_out.data[start:end]
@ -162,7 +162,6 @@ def merge_chunks(

    futures = []
    with ThreadPoolExecutor(max_workers=1) as executor:

        for cut in cuts_chunk:
            cur_rec_id = cut.recording.id
            if len(cut_list) == 0:
@ -264,6 +264,7 @@ def decode_dataset(
        - timestamps of reference transcript
        - timestamps of predicted result
    """

    # Background worker to add alignment and save cuts to disk.
    def _save_worker(
        cuts: List[Cut],
@ -57,8 +57,7 @@ def test_model():

    convert_scaled_to_non_scaled(model, inplace=True)

    if not os.path.exists(params.exp_dir):
        os.path.mkdir(params.exp_dir)
    params.exp_dir.mkdir(exist_ok=True)

    encoder_filename = params.exp_dir / "encoder_jit_trace.pt"
    export_encoder_model_jit_trace(model.encoder, encoder_filename)
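The removed branch is worth a note: os.path.mkdir does not exist (the function is os.mkdir; os.path only inspects paths), so the old code raised AttributeError whenever the directory was missing. The pathlib replacement also folds the existence check into one call:

    from pathlib import Path

    exp_dir = Path("exp")         # placeholder directory
    exp_dir.mkdir(exist_ok=True)  # creates it, or silently succeeds if present
    exp_dir.mkdir(exist_ok=True)  # safe to repeat, unlike a bare os.mkdir()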
@ -359,6 +359,7 @@ def export_decoder_model_onnx(
    vocab_size = decoder_model.decoder.vocab_size

    y = torch.zeros(10, context_size, dtype=torch.int64)
    decoder_model = torch.jit.script(decoder_model)
    torch.onnx.export(
        decoder_model,
        y,
@ -356,6 +356,7 @@ def export_decoder_model_onnx(
    vocab_size = decoder_model.decoder.vocab_size

    y = torch.zeros(10, context_size, dtype=torch.int64)
    decoder_model = torch.jit.script(decoder_model)
    torch.onnx.export(
        decoder_model,
        y,
@ -129,6 +129,7 @@ class OnnxModel:
        self.encoder = ort.InferenceSession(
            encoder_model_filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )
        self.init_encoder_states()

@ -166,6 +167,7 @@ class OnnxModel:
        self.decoder = ort.InferenceSession(
            decoder_model_filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

        decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
@ -179,6 +181,7 @@ class OnnxModel:
        self.joiner = ort.InferenceSession(
            joiner_model_filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

        joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
@ -172,30 +172,35 @@ class Model:
        self.encoder = ort.InferenceSession(
            args.encoder_model_filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

    def init_decoder(self, args):
        self.decoder = ort.InferenceSession(
            args.decoder_model_filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

    def init_joiner(self, args):
        self.joiner = ort.InferenceSession(
            args.joiner_model_filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

    def init_joiner_encoder_proj(self, args):
        self.joiner_encoder_proj = ort.InferenceSession(
            args.joiner_encoder_proj_model_filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

    def init_joiner_decoder_proj(self, args):
        self.joiner_decoder_proj = ort.InferenceSession(
            args.joiner_decoder_proj_model_filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

    def run_encoder(self, x, h0, c0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@ -242,6 +242,10 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
        $lang_dir/L_disambig.pt \
        $lang_dir/L_disambig.fst
    fi

    if [ ! -f $lang_dir/HL.fst ]; then
      ./local/prepare_lang_fst.py --lang-dir $lang_dir --ngram-G ./data/lm/G_3_gram.fst.txt
    fi
  done
fi

@ -66,7 +66,6 @@ class Eve(Optimizer):
        weight_decay=1e-3,
        target_rms=0.1,
    ):

        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
@ -811,7 +811,7 @@ def run(rank, world_size, args):

    if params.print_diagnostics:
        opts = diagnostics.TensorDiagnosticOptions(
            2**22
            512
        )  # allow 4 megabytes per sub-module
        diagnostic = diagnostics.attach_diagnostics(model, opts)

@ -307,6 +307,7 @@ def export_decoder_model_onnx(
    vocab_size = decoder_model.decoder.vocab_size

    y = torch.zeros(10, context_size, dtype=torch.int64)
    decoder_model = torch.jit.script(decoder_model)
    torch.onnx.export(
        decoder_model,
        y,
@ -719,7 +719,7 @@ def greedy_search_batch(
    encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)

    offset = 0
    for (t, batch_size) in enumerate(batch_size_list):
    for t, batch_size in enumerate(batch_size_list):
        start = offset
        end = offset + batch_size
        current_encoder_out = encoder_out.data[start:end]
@ -1019,7 +1019,7 @@ def modified_beam_search(

    offset = 0
    finalized_B = []
    for (t, batch_size) in enumerate(batch_size_list):
    for t, batch_size in enumerate(batch_size_list):
        start = offset
        end = offset + batch_size
        current_encoder_out = encoder_out.data[start:end]
@ -1227,7 +1227,7 @@ def modified_beam_search_lm_rescore(

    offset = 0
    finalized_B = []
    for (t, batch_size) in enumerate(batch_size_list):
    for t, batch_size in enumerate(batch_size_list):
        start = offset
        end = offset + batch_size
        current_encoder_out = encoder_out.data[start:end]
@ -1427,7 +1427,7 @@ def modified_beam_search_lm_rescore_LODR(

    offset = 0
    finalized_B = []
    for (t, batch_size) in enumerate(batch_size_list):
    for t, batch_size in enumerate(batch_size_list):
        start = offset
        end = offset + batch_size
        current_encoder_out = encoder_out.data[start:end]
@ -2608,7 +2608,6 @@ def modified_beam_search_LODR(
                context_score = 0
                new_context_state = None if context_graph is None else hyp.context_state
                if new_token not in (blank_id, unk_id):

                    if context_graph is not None:
                        (
                            context_score,
@ -2758,7 +2757,7 @@ def modified_beam_search_lm_shallow_fusion(

    offset = 0
    finalized_B = []
    for (t, batch_size) in enumerate(batch_size_list):
    for t, batch_size in enumerate(batch_size_list):
        start = offset
        end = offset + batch_size
        current_encoder_out = encoder_out.data[start:end]  # get batch
@ -2900,7 +2899,6 @@ def modified_beam_search_lm_shallow_fusion(
            new_token = topk_token_indexes[k]
            new_timestamp = hyp.timestamp[:]
            if new_token not in (blank_id, unk_id):

                ys.append(new_token)
                new_timestamp.append(t)

@ -66,7 +66,6 @@ class Eve(Optimizer):
        weight_decay=1e-3,
        target_rms=0.1,
    ):

        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
@ -528,7 +528,6 @@ class ScaledLSTM(nn.LSTM):
            return

        with torch.cuda.device_of(first_fw):

            # Note: no_grad() is necessary since _cudnn_rnn_flatten_weight is
            # an inplace operation on self._flat_weights
            with torch.no_grad():
@ -312,6 +312,7 @@ def export_decoder_model_onnx(
    vocab_size = decoder_model.decoder.vocab_size

    y = torch.zeros(10, context_size, dtype=torch.int64)
    decoder_model = torch.jit.script(decoder_model)
    torch.onnx.export(
        decoder_model,
        y,
@ -150,12 +150,14 @@ class OnnxModel:
        self.encoder = ort.InferenceSession(
            encoder_model_filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

    def init_decoder(self, decoder_model_filename: str):
        self.decoder = ort.InferenceSession(
            decoder_model_filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

        decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
@ -169,6 +171,7 @@ class OnnxModel:
        self.joiner = ort.InferenceSession(
            joiner_model_filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

        joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
@ -78,6 +78,7 @@ def test_conv2d_subsampling():
    session = ort.InferenceSession(
        filename,
        sess_options=options,
        providers=["CPUExecutionProvider"],
    )

    input_nodes = session.get_inputs()
@ -133,6 +134,7 @@ def test_rel_pos():
    session = ort.InferenceSession(
        filename,
        sess_options=options,
        providers=["CPUExecutionProvider"],
    )

    input_nodes = session.get_inputs()
@ -220,6 +222,7 @@ def test_conformer_encoder_layer():
    session = ort.InferenceSession(
        filename,
        sess_options=options,
        providers=["CPUExecutionProvider"],
    )

    input_nodes = session.get_inputs()
@ -304,6 +307,7 @@ def test_conformer_encoder():
    session = ort.InferenceSession(
        filename,
        sess_options=options,
        providers=["CPUExecutionProvider"],
    )

    input_nodes = session.get_inputs()
@ -359,6 +363,7 @@ def test_conformer():
    session = ort.InferenceSession(
        filename,
        sess_options=options,
        providers=["CPUExecutionProvider"],
    )

    input_nodes = session.get_inputs()
@ -404,6 +404,7 @@ def export_decoder_model_onnx(
    vocab_size = decoder_model.decoder.vocab_size

    y = torch.zeros(10, context_size, dtype=torch.int64)
    decoder_model = torch.jit.script(decoder_model)
    torch.onnx.export(
        decoder_model,
        y,
@ -335,6 +335,7 @@ def export_decoder_model_onnx(
    vocab_size = decoder_model.decoder.vocab_size

    y = torch.zeros(10, context_size, dtype=torch.int64)
    decoder_model = torch.jit.script(decoder_model)
    torch.onnx.export(
        decoder_model,
        y,
@ -138,6 +138,7 @@ class OnnxModel:
        self.encoder = ort.InferenceSession(
            encoder_model_filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )
        self.init_encoder_states()

@ -185,6 +186,7 @@ class OnnxModel:
        self.decoder = ort.InferenceSession(
            decoder_model_filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

        decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
@ -198,6 +200,7 @@ class OnnxModel:
        self.joiner = ort.InferenceSession(
            joiner_model_filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

        joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
@ -1003,7 +1003,7 @@ def run(rank, world_size, args):

    if params.print_diagnostics:
        opts = diagnostics.TensorDiagnosticOptions(
            2**22
            512
        )  # allow 4 megabytes per sub-module
        diagnostic = diagnostics.attach_diagnostics(model, opts)

@ -56,7 +56,6 @@ class CodebookIndexExtractor:
    """

    def __init__(self, params: AttributeDict):

        self.params = params
        params.subsets = ["clean-100"]
        if self.params.full_libri:
@ -111,7 +111,7 @@ def batch_force_alignment(

    offset = 0
    finalized_B = []
    for (t, batch_size) in enumerate(batch_size_list):
    for t, batch_size in enumerate(batch_size_list):
        start = offset
        end = offset + batch_size
        current_encoder_out = encoder_out.data[start:end]
@ -71,6 +71,10 @@ class Decoder(nn.Module):
                groups=decoder_dim // 4,  # group size == 4
                bias=False,
            )
        else:
            # To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'`
            # when running inference with torch.jit.script and context_size == 1
            self.conv = nn.Identity()

    def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
        """
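The nn.Identity() fallback above is the usual workaround for torchscript: torch.jit.script type-checks every attribute access in forward, so self.conv has to exist even when context_size == 1 and the convolution is never used. A minimal reproduction of the pattern (the module here is ours, not icefall's Decoder):

    import torch
    import torch.nn as nn

    class TinyDecoder(nn.Module):
        def __init__(self, context_size: int):
            super().__init__()
            if context_size > 1:
                self.conv = nn.Conv1d(4, 4, kernel_size=context_size)
            else:
                # Without this, scripting fails: Module has no attribute 'conv'
                self.conv = nn.Identity()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.conv(x)

    scripted = torch.jit.script(TinyDecoder(context_size=1))  # compiles cleanly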
@ -329,6 +329,7 @@ def export_decoder_model_onnx(
    vocab_size = decoder_model.decoder.vocab_size

    y = torch.zeros(10, context_size, dtype=torch.int64)
    decoder_model = torch.jit.script(decoder_model)
    torch.onnx.export(
        decoder_model,
        y,
@ -1132,7 +1132,7 @@ def run(rank, world_size, args):

    if params.print_diagnostics:
        opts = diagnostics.TensorDiagnosticOptions(
            2**22
            512
        )  # allow 4 megabytes per sub-module
        diagnostic = diagnostics.attach_diagnostics(model, opts)

@ -117,7 +117,7 @@ class BatchedOptimizer(Optimizer):

        yield tuples  # <-- calling code will do the actual optimization here!

        for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
        for (stacked_params, _state, _names), batch in zip(tuples, batches):
            for i, p in enumerate(batch):  # batch is list of Parameter
                p.copy_(stacked_params[i])

@ -181,7 +181,6 @@ class ScaledAdam(BatchedOptimizer):
        parameters_names=None,
        show_dominant_parameters=True,
    ):

        assert parameters_names is not None, (
            "Please prepare parameters_names,"
            "which is a List[List[str]]. Each List[str] is for a group"
@ -224,9 +223,7 @@ class ScaledAdam(BatchedOptimizer):
        batch = True

        for group, group_params_names in zip(self.param_groups, self.parameters_names):

            with self.batched_params(group["params"], group_params_names) as batches:

                # batches is list of pairs (stacked_param, state). stacked_param is like
                # a regular parameter, and will have a .grad, but the 1st dim corresponds to
                # a stacking dim, it is not a real dim.
@ -325,7 +322,7 @@ class ScaledAdam(BatchedOptimizer):
        clipping_update_period = group["clipping_update_period"]

        tot_sumsq = torch.tensor(0.0, device=first_p.device)
        for (p, state, param_names) in tuples:
        for p, state, param_names in tuples:
            grad = p.grad
            if grad.is_sparse:
                raise RuntimeError(
@ -410,7 +407,7 @@ class ScaledAdam(BatchedOptimizer):
        from tuples, we still pass it to save some time.
        """
        all_sumsq_orig = {}
        for (p, state, batch_param_names) in tuples:
        for p, state, batch_param_names in tuples:
            # p is a stacked batch of parameters.
            batch_grad = p.grad
            if p.numel() == p.shape[0]:  # a batch of scalars
@ -426,7 +423,6 @@ class ScaledAdam(BatchedOptimizer):
            for name, sumsq_orig, rms, grad in zip(
                batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
            ):

                proportion_orig = sumsq_orig / tot_sumsq
                all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)

@ -1039,7 +1035,7 @@ def _test_scaled_adam(hidden_dim: int):

    # if epoch == 130:
    #     opts = diagnostics.TensorDiagnosticOptions(
    #         2 ** 22
    #         512
    #     )  # allow 4 megabytes per sub-module
    #     diagnostic = diagnostics.attach_diagnostics(m, opts)

@ -74,6 +74,7 @@ def test_conv2d_subsampling():
    session = ort.InferenceSession(
        filename,
        sess_options=options,
        providers=["CPUExecutionProvider"],
    )

    input_nodes = session.get_inputs()
@ -128,6 +129,7 @@ def test_rel_pos():
    session = ort.InferenceSession(
        filename,
        sess_options=options,
        providers=["CPUExecutionProvider"],
    )

    input_nodes = session.get_inputs()
@ -204,6 +206,7 @@ def test_zipformer_encoder_layer():
    session = ort.InferenceSession(
        filename,
        sess_options=options,
        providers=["CPUExecutionProvider"],
    )

    input_nodes = session.get_inputs()
@ -284,6 +287,7 @@ def test_zipformer_encoder():
    session = ort.InferenceSession(
        filename,
        sess_options=options,
        providers=["CPUExecutionProvider"],
    )

    input_nodes = session.get_inputs()
@ -338,6 +342,7 @@ def test_zipformer():
    session = ort.InferenceSession(
        filename,
        sess_options=options,
        providers=["CPUExecutionProvider"],
    )

    input_nodes = session.get_inputs()
@ -1028,7 +1028,7 @@ def run(rank, world_size, args):

    if params.print_diagnostics:
        opts = diagnostics.TensorDiagnosticOptions(
            2**22
            512
        )  # allow 4 megabytes per sub-module
        diagnostic = diagnostics.attach_diagnostics(model, opts)

@ -1052,7 +1052,7 @@ def run(rank, world_size, args):

    if params.print_diagnostics:
        opts = diagnostics.TensorDiagnosticOptions(
            2**22
            512
        )  # allow 4 megabytes per sub-module
        diagnostic = diagnostics.attach_diagnostics(model, opts)

@ -326,41 +326,49 @@ def main():
    encoder = ort.InferenceSession(
        args.encoder_model_filename,
        sess_options=session_opts,
        providers=["CPUExecutionProvider"],
    )

    decoder = ort.InferenceSession(
        args.decoder_model_filename,
        sess_options=session_opts,
        providers=["CPUExecutionProvider"],
    )

    joiner = ort.InferenceSession(
        args.joiner_model_filename,
        sess_options=session_opts,
        providers=["CPUExecutionProvider"],
    )

    joiner_encoder_proj = ort.InferenceSession(
        args.joiner_encoder_proj_model_filename,
        sess_options=session_opts,
        providers=["CPUExecutionProvider"],
    )

    joiner_decoder_proj = ort.InferenceSession(
        args.joiner_decoder_proj_model_filename,
        sess_options=session_opts,
        providers=["CPUExecutionProvider"],
    )

    lconv = ort.InferenceSession(
        args.lconv_filename,
        sess_options=session_opts,
        providers=["CPUExecutionProvider"],
    )

    frame_reducer = ort.InferenceSession(
        args.frame_reducer_filename,
        sess_options=session_opts,
        providers=["CPUExecutionProvider"],
    )

    ctc_output = ort.InferenceSession(
        args.ctc_output_filename,
        sess_options=session_opts,
        providers=["CPUExecutionProvider"],
    )

    sp = spm.SentencePieceProcessor()
@ -1042,7 +1042,7 @@ def run(rank, world_size, args):

    if params.print_diagnostics:
        opts = diagnostics.TensorDiagnosticOptions(
            2**22
            512
        )  # allow 4 megabytes per sub-module
        diagnostic = diagnostics.attach_diagnostics(model, opts)

@ -413,6 +413,7 @@ def export_decoder_model_onnx(
    context_size = decoder_model.decoder.context_size
    vocab_size = decoder_model.decoder.vocab_size
    y = torch.zeros(10, context_size, dtype=torch.int64)
    decoder_model = torch.jit.script(decoder_model)
    torch.onnx.export(
        decoder_model,
        y,
@ -401,6 +401,7 @@ def export_decoder_model_onnx(
    context_size = decoder_model.decoder.context_size
    vocab_size = decoder_model.decoder.vocab_size
    y = torch.zeros(10, context_size, dtype=torch.int64)
    decoder_model = torch.jit.script(decoder_model)
    torch.onnx.export(
        decoder_model,
        y,
@ -130,6 +130,7 @@ class OnnxModel:
        self.encoder = ort.InferenceSession(
            encoder_model_filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )
        self.init_encoder_states()

@ -229,6 +230,7 @@ class OnnxModel:
        self.decoder = ort.InferenceSession(
            decoder_model_filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

        decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
@ -242,6 +244,7 @@ class OnnxModel:
        self.joiner = ort.InferenceSession(
            joiner_model_filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

        joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
@ -1029,7 +1029,7 @@ def run(rank, world_size, args):

    if params.print_diagnostics:
        opts = diagnostics.TensorDiagnosticOptions(
            2**22
            512
        )  # allow 4 megabytes per sub-module
        diagnostic = diagnostics.attach_diagnostics(model, opts)

@ -1030,7 +1030,7 @@ def run(rank, world_size, args):

    if params.print_diagnostics:
        opts = diagnostics.TensorDiagnosticOptions(
            2**22
            512
        )  # allow 4 megabytes per sub-module
        diagnostic = diagnostics.attach_diagnostics(model, opts)

@ -865,7 +865,7 @@ class ZipformerEncoderLayer(nn.Module):
            return final_dropout_rate
        else:
            return initial_dropout_rate - (
                initial_dropout_rate * final_dropout_rate
                initial_dropout_rate - final_dropout_rate
            ) * (self.batch_count / warmup_period)

    def forward(
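The one-character fix above changes the schedule's endpoint: with *, dropout ends at initial - initial * final after warmup; with -, it interpolates linearly from the initial rate down to the final one. A quick check with made-up rates:

    initial, final, warmup = 0.3, 0.1, 2000.0

    old = initial - (initial * final) * (2000 / warmup)
    new = initial - (initial - final) * (2000 / warmup)

    print(old)  # ~0.27 -- never reaches the intended final rate
    print(new)  # ~0.1  -- linear interpolation ends at the final rate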
@ -1141,7 +1141,7 @@ def run(rank, world_size, args):

    if params.print_diagnostics:
        opts = diagnostics.TensorDiagnosticOptions(
            2**22
            512
        )  # allow 4 megabytes per sub-module
        diagnostic = diagnostics.attach_diagnostics(model, opts)

@ -1154,7 +1154,7 @@ def run(rank, world_size, args):

    if params.print_diagnostics:
        opts = diagnostics.TensorDiagnosticOptions(
            2**22
            512
        )  # allow 4 megabytes per sub-module
        diagnostic = diagnostics.attach_diagnostics(model, opts)

@ -230,7 +230,9 @@ class Conformer(Transformer):
                x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask
            )  # (T, B, F)
        else:
            x = self.encoder(x, pos_emb, src_key_padding_mask=mask)  # (T, B, F)
            x = self.encoder(
                x, pos_emb, src_key_padding_mask=src_key_padding_mask
            )  # (T, B, F)

        if self.normalize_before:
            x = self.after_norm(x)
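For context on the fix above: mask is the square attention mask of shape (T, T) that limits which frames may attend to which, while src_key_padding_mask has shape (B, T) and marks the padded frames of each utterance; passing one where the other belongs silently masks the wrong positions. A shape-only sketch (sizes are illustrative):

    import torch

    T, B = 6, 2
    attn_mask = torch.zeros(T, T, dtype=torch.bool)             # frame-to-frame mask
    src_key_padding_mask = torch.zeros(B, T, dtype=torch.bool)  # padding per utterance

    assert attn_mask.shape != src_key_padding_mask.shape  # not interchangeable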
@ -543,7 +543,6 @@ def train_one_epoch(
                )

            if batch_idx % params.log_interval == 0:

                if tb_writer is not None:
                    loss_info.write_summary(
                        tb_writer, "train/current_", params.batch_idx_train
@ -463,7 +463,6 @@ def train_one_epoch(
                    f"tot_loss[{tot_loss}], batch size: {batch_size}"
                )
            if batch_idx % params.log_interval == 0:

                if tb_writer is not None:
                    loss_info.write_summary(
                        tb_writer, "train/current_", params.batch_idx_train
@ -513,7 +513,6 @@ def train_one_epoch(
                )

            if batch_idx % params.log_interval == 0:

                if tb_writer is not None:
                    loss_info.write_summary(
                        tb_writer, "train/current_", params.batch_idx_train