Mirror of https://github.com/k2-fsa/icefall.git
Synced 2025-12-11 06:55:27 +00:00

Support exporting to ONNX for the wenetspeech recipe (#615)

* Support exporting to ONNX for the wenetspeech recipe

This commit is contained in:
parent aa58c2ee02
commit c39cba5191
.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh (vendored executable file, 124 lines)
@@ -0,0 +1,124 @@
#!/usr/bin/env bash

set -e

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

cd egs/wenetspeech/ASR

repo_url=https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2

log "Downloading pre-trained model from $repo_url"
git lfs install
git clone $repo_url
repo=$(basename $repo_url)

log "Display test files"
tree $repo/
soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav

pushd $repo/exp
ln -s pretrained_epoch_10_avg_2.pt pretrained.pt
ln -s pretrained_epoch_10_avg_2.pt epoch-99.pt
popd

log "Test exporting to ONNX format"

./pruned_transducer_stateless2/export.py \
  --exp-dir $repo/exp \
  --lang-dir $repo/data/lang_char \
  --epoch 99 \
  --avg 1 \
  --onnx 1

log "Export to torchscript model"

./pruned_transducer_stateless2/export.py \
  --exp-dir $repo/exp \
  --lang-dir $repo/data/lang_char \
  --epoch 99 \
  --avg 1 \
  --jit 1

./pruned_transducer_stateless2/export.py \
  --exp-dir $repo/exp \
  --lang-dir $repo/data/lang_char \
  --epoch 99 \
  --avg 1 \
  --jit-trace 1

ls -lh $repo/exp/*.onnx
ls -lh $repo/exp/*.pt

log "Decode with ONNX models"

./pruned_transducer_stateless2/onnx_check.py \
  --jit-filename $repo/exp/cpu_jit.pt \
  --onnx-encoder-filename $repo/exp/encoder.onnx \
  --onnx-decoder-filename $repo/exp/decoder.onnx \
  --onnx-joiner-filename $repo/exp/joiner.onnx \
  --onnx-joiner-encoder-proj-filename $repo/exp/joiner_encoder_proj.onnx \
  --onnx-joiner-decoder-proj-filename $repo/exp/joiner_decoder_proj.onnx

./pruned_transducer_stateless2/onnx_pretrained.py \
  --tokens $repo/data/lang_char/tokens.txt \
  --encoder-model-filename $repo/exp/encoder.onnx \
  --decoder-model-filename $repo/exp/decoder.onnx \
  --joiner-model-filename $repo/exp/joiner.onnx \
  --joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \
  --joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \
  $repo/test_wavs/DEV_T0000000000.wav \
  $repo/test_wavs/DEV_T0000000001.wav \
  $repo/test_wavs/DEV_T0000000002.wav

log "Decode with models exported by torch.jit.trace()"

./pruned_transducer_stateless2/jit_pretrained.py \
  --tokens $repo/data/lang_char/tokens.txt \
  --encoder-model-filename $repo/exp/encoder_jit_trace.pt \
  --decoder-model-filename $repo/exp/decoder_jit_trace.pt \
  --joiner-model-filename $repo/exp/joiner_jit_trace.pt \
  $repo/test_wavs/DEV_T0000000000.wav \
  $repo/test_wavs/DEV_T0000000001.wav \
  $repo/test_wavs/DEV_T0000000002.wav

./pruned_transducer_stateless2/jit_pretrained.py \
  --tokens $repo/data/lang_char/tokens.txt \
  --encoder-model-filename $repo/exp/encoder_jit_script.pt \
  --decoder-model-filename $repo/exp/decoder_jit_script.pt \
  --joiner-model-filename $repo/exp/joiner_jit_script.pt \
  $repo/test_wavs/DEV_T0000000000.wav \
  $repo/test_wavs/DEV_T0000000001.wav \
  $repo/test_wavs/DEV_T0000000002.wav

for sym in 1 2 3; do
  log "Greedy search with --max-sym-per-frame $sym"

  ./pruned_transducer_stateless2/pretrained.py \
    --checkpoint $repo/exp/epoch-99.pt \
    --lang-dir $repo/data/lang_char \
    --decoding-method greedy_search \
    --max-sym-per-frame $sym \
    $repo/test_wavs/DEV_T0000000000.wav \
    $repo/test_wavs/DEV_T0000000001.wav \
    $repo/test_wavs/DEV_T0000000002.wav
done

for method in modified_beam_search beam_search fast_beam_search; do
  log "$method"

  ./pruned_transducer_stateless2/pretrained.py \
    --decoding-method $method \
    --beam-size 4 \
    --checkpoint $repo/exp/epoch-99.pt \
    --lang-dir $repo/data/lang_char \
    $repo/test_wavs/DEV_T0000000000.wav \
    $repo/test_wavs/DEV_T0000000001.wav \
    $repo/test_wavs/DEV_T0000000002.wav
done
.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml (vendored new file, 80 lines)
@@ -0,0 +1,80 @@
# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)

# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

name: run-wenetspeech-pruned-transducer-stateless2

on:
  push:
    branches:
      - master
  pull_request:
    types: [labeled]

jobs:
  run_wenetspeech_pruned_transducer_stateless2:
    if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'wenetspeech'
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [ubuntu-18.04]
        python-version: [3.8]

      fail-fast: false

    steps:
      - uses: actions/checkout@v2
        with:
          fetch-depth: 0

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
          cache: 'pip'
          cache-dependency-path: '**/requirements-ci.txt'

      - name: Install Python dependencies
        run: |
          grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
          pip uninstall -y protobuf
          pip install --no-binary protobuf protobuf

      - name: Cache kaldifeat
        id: my-cache
        uses: actions/cache@v2
        with:
          path: |
            ~/tmp/kaldifeat
          key: cache-tmp-${{ matrix.python-version }}-2022-09-25

      - name: Install kaldifeat
        if: steps.my-cache.outputs.cache-hit != 'true'
        shell: bash
        run: |
          .github/scripts/install-kaldifeat.sh

      - name: Inference with pre-trained model
        shell: bash
        env:
          GITHUB_EVENT_NAME: ${{ github.event_name }}
          GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
        run: |
          sudo apt-get -qq install git-lfs tree sox
          export PYTHONPATH=$PWD:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH

          .github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh
@@ -73,6 +73,10 @@ Check `onnx_check.py` for how to use them.

 Please see ./onnx_pretrained.py for usage of the generated files

+Check
+https://github.com/k2-fsa/sherpa-onnx
+for how to use the exported models outside of icefall.
+
 (4) Export `model.state_dict()`

 ./pruned_transducer_stateless3/export.py \
@@ -31,6 +31,8 @@ Usage of this script:
   --encoder-model-filename ./pruned_transducer_stateless3/exp/encoder.onnx \
   --decoder-model-filename ./pruned_transducer_stateless3/exp/decoder.onnx \
   --joiner-model-filename ./pruned_transducer_stateless3/exp/joiner.onnx \
+  --joiner-encoder-proj-model-filename ./pruned_transducer_stateless3/exp/joiner_encoder_proj.onnx \
+  --joiner-decoder-proj-model-filename ./pruned_transducer_stateless3/exp/joiner_decoder_proj.onnx \
   --bpe-model ./data/lang_bpe_500/bpe.model \
   /path/to/foo.wav \
   /path/to/bar.wav
egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py (normal file → executable file, 484 lines)
@@ -1,3 +1,4 @@
+#!/usr/bin/env python3
 # Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
@@ -18,6 +19,64 @@
 # to a single one using model averaging.
 """
 Usage:
+
+(1) Export to torchscript model using torch.jit.script()
+
+./pruned_transducer_stateless2/export.py \
+  --exp-dir ./pruned_transducer_stateless2/exp \
+  --lang-dir data/lang_char \
+  --epoch 10 \
+  --avg 2 \
+  --jit 1
+
+It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
+load it by `torch.jit.load("cpu_jit.pt")`.
+
+Note: `cpu` in the name `cpu_jit.pt` means that the parameters, when loaded
+into Python, are on CPU. You can use `to("cuda")` to move them to a CUDA device.
+
+Please refer to
+https://k2-fsa.github.io/sherpa/python/offline_asr/conformer/index.html
+for how to use `cpu_jit.pt` for speech recognition.
+
+It will also generate 3 other files: `encoder_jit_script.pt`,
+`decoder_jit_script.pt`, and `joiner_jit_script.pt`. Check ./jit_pretrained.py
+for how to use them.
+
+(2) Export to torchscript model using torch.jit.trace()
+
+./pruned_transducer_stateless2/export.py \
+  --exp-dir ./pruned_transducer_stateless2/exp \
+  --lang-dir data/lang_char \
+  --epoch 10 \
+  --avg 2 \
+  --jit-trace 1
+
+It will generate the following files:
+  - encoder_jit_trace.pt
+  - decoder_jit_trace.pt
+  - joiner_jit_trace.pt
+
+Check ./jit_pretrained.py for usage.
+
+(3) Export to ONNX format
+
+./pruned_transducer_stateless2/export.py \
+  --exp-dir ./pruned_transducer_stateless2/exp \
+  --lang-dir data/lang_char \
+  --epoch 10 \
+  --avg 2 \
+  --onnx 1
+
+Refer to ./onnx_check.py and ./onnx_pretrained.py
+for usage.
+
+Check
+https://github.com/k2-fsa/sherpa-onnx
+for how to use the exported models outside of icefall.
+
+(4) Export `model.state_dict()`
+
 ./pruned_transducer_stateless2/export.py \
   --exp-dir ./pruned_transducer_stateless2/exp \
   --lang-dir data/lang_char \
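As a quick aside on the `--jit 1` output documented above: below is a minimal sketch (not part of this commit) of loading `cpu_jit.pt` and running its encoder on dummy fbank features. The `encoder` attribute and the `(x, x_lens)` call signature follow the export code later in this diff; the shapes are purely illustrative.

import torch

model = torch.jit.load("cpu_jit.pt")  # parameters are on CPU after loading
model.eval()
# model.to("cuda")  # optionally move the parameters to a CUDA device

x = torch.randn(1, 100, 80)   # (N, T, C): dummy 80-dim fbank features
x_lens = torch.tensor([100])  # number of valid frames per utterance
with torch.no_grad():
    encoder_out, encoder_out_lens = model.encoder(x, x_lens)
print(encoder_out.shape, encoder_out_lens)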
@@ -35,10 +94,13 @@ you can do:
 cd /path/to/egs/wenetspeech/ASR
 ./pruned_transducer_stateless2/decode.py \
   --exp-dir ./pruned_transducer_stateless2/exp \
-  --epoch 10 \
-  --avg 2 \
+  --epoch 9999 \
+  --avg 1 \
   --max-duration 100 \
   --lang-dir data/lang_char
+
+You can find pretrained models at
+https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2/tree/main/exp
 """

 import argparse
@@ -46,6 +108,8 @@ import logging
 from pathlib import Path

 import torch
+import torch.nn as nn
+from scaling_converter import convert_scaled_to_non_scaled
 from train import get_params, get_transducer_model

 from icefall.checkpoint import average_checkpoints, load_checkpoint
@@ -96,6 +160,44 @@ def get_parser():
         type=str2bool,
         default=False,
         help="""True to save a model after applying torch.jit.script.
+        It will generate 4 files:
+
+          - encoder_jit_script.pt
+          - decoder_jit_script.pt
+          - joiner_jit_script.pt
+          - cpu_jit.pt (which combines the above 3 files)
+
+        Check ./jit_pretrained.py for how to use xxx_jit_script.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--jit-trace",
+        type=str2bool,
+        default=False,
+        help="""True to save a model after applying torch.jit.trace.
+        It will generate 3 files:
+
+          - encoder_jit_trace.pt
+          - decoder_jit_trace.pt
+          - joiner_jit_trace.pt
+
+        Check ./jit_pretrained.py for how to use them.
+        """,
+    )
+
+    parser.add_argument(
+        "--onnx",
+        type=str2bool,
+        default=False,
+        help="""If True, --jit is ignored and it exports the model
+        to onnx format. It will generate the following files:
+
+          - encoder.onnx
+          - decoder.onnx
+          - joiner.onnx
+          - joiner_encoder_proj.onnx
+          - joiner_decoder_proj.onnx
+
+        Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them.
         """,
     )

@@ -110,6 +212,332 @@ def get_parser():
     return parser


+def export_encoder_model_jit_script(
+    encoder_model: nn.Module,
+    encoder_filename: str,
+) -> None:
+    """Export the given encoder model with torch.jit.script()
+
+    Args:
+      encoder_model:
+        The input encoder model.
+      encoder_filename:
+        The filename to save the exported model.
+    """
+    script_model = torch.jit.script(encoder_model)
+    script_model.save(encoder_filename)
+    logging.info(f"Saved to {encoder_filename}")
+
+
+def export_decoder_model_jit_script(
+    decoder_model: nn.Module,
+    decoder_filename: str,
+) -> None:
+    """Export the given decoder model with torch.jit.script()
+
+    Args:
+      decoder_model:
+        The input decoder model.
+      decoder_filename:
+        The filename to save the exported model.
+    """
+    script_model = torch.jit.script(decoder_model)
+    script_model.save(decoder_filename)
+    logging.info(f"Saved to {decoder_filename}")
+
+
+def export_joiner_model_jit_script(
+    joiner_model: nn.Module,
+    joiner_filename: str,
+) -> None:
+    """Export the given joiner model with torch.jit.script()
+
+    Args:
+      joiner_model:
+        The input joiner model.
+      joiner_filename:
+        The filename to save the exported model.
+    """
+    script_model = torch.jit.script(joiner_model)
+    script_model.save(joiner_filename)
+    logging.info(f"Saved to {joiner_filename}")
+
+
+def export_encoder_model_jit_trace(
+    encoder_model: nn.Module,
+    encoder_filename: str,
+) -> None:
+    """Export the given encoder model with torch.jit.trace()
+
+    Note: The warmup argument is fixed to 1.
+
+    Args:
+      encoder_model:
+        The input encoder model.
+      encoder_filename:
+        The filename to save the exported model.
+    """
+    x = torch.zeros(1, 100, 80, dtype=torch.float32)
+    x_lens = torch.tensor([100], dtype=torch.int64)
+
+    traced_model = torch.jit.trace(encoder_model, (x, x_lens))
+    traced_model.save(encoder_filename)
+    logging.info(f"Saved to {encoder_filename}")
+
+
+def export_decoder_model_jit_trace(
+    decoder_model: nn.Module,
+    decoder_filename: str,
+) -> None:
+    """Export the given decoder model with torch.jit.trace()
+
+    Note: The argument need_pad is fixed to False.
+
+    Args:
+      decoder_model:
+        The input decoder model.
+      decoder_filename:
+        The filename to save the exported model.
+    """
+    y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
+    need_pad = torch.tensor([False])
+
+    traced_model = torch.jit.trace(decoder_model, (y, need_pad))
+    traced_model.save(decoder_filename)
+    logging.info(f"Saved to {decoder_filename}")
+
+
+def export_joiner_model_jit_trace(
+    joiner_model: nn.Module,
+    joiner_filename: str,
+) -> None:
+    """Export the given joiner model with torch.jit.trace()
+
+    Note: The argument project_input is fixed to True. Users should not
+    project the encoder_out/decoder_out themselves; the exported joiner
+    does that for them.
+
+    Args:
+      joiner_model:
+        The input joiner model.
+      joiner_filename:
+        The filename to save the exported model.
+    """
+    encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
+    decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
+    encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
+    decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
+
+    traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out))
+    traced_model.save(joiner_filename)
+    logging.info(f"Saved to {joiner_filename}")
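Because torch.jit.trace() records a single concrete execution, a common sanity check is to verify that the traced module reproduces the eager module's output on fresh input. A small self-contained sketch (not part of this commit; it uses a stand-in nn.Linear rather than a real encoder/decoder/joiner):

import torch
import torch.nn as nn

module = nn.Linear(8, 4)  # stand-in for a real encoder/decoder/joiner
traced = torch.jit.trace(module, torch.randn(2, 8))

x = torch.randn(3, 8)  # a different batch size from the example input
assert torch.allclose(module(x), traced(x))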
+def export_encoder_model_onnx(
+    encoder_model: nn.Module,
+    encoder_filename: str,
+    opset_version: int = 11,
+) -> None:
+    """Export the given encoder model to ONNX format.
+    The exported model has two inputs:
+
+      - x, a tensor of shape (N, T, C); dtype is torch.float32
+      - x_lens, a tensor of shape (N,); dtype is torch.int64
+
+    and it has two outputs:
+
+      - encoder_out, a tensor of shape (N, T, C)
+      - encoder_out_lens, a tensor of shape (N,)
+
+    Note: The warmup argument is fixed to 1.
+
+    Args:
+      encoder_model:
+        The input encoder model.
+      encoder_filename:
+        The filename to save the exported ONNX model.
+      opset_version:
+        The opset version to use.
+    """
+    x = torch.zeros(1, 100, 80, dtype=torch.float32)
+    x_lens = torch.tensor([100], dtype=torch.int64)
+
+    # encoder_model = torch.jit.script(encoder_model)
+    # It throws the following error for the above statement:
+    #
+    # RuntimeError: Exporting the operator __is_ to ONNX opset version
+    # 11 is not supported. Please feel free to request support or
+    # submit a pull request on PyTorch GitHub.
+    #
+    # I cannot find which statement causes the above error.
+    # torch.onnx.export() will use torch.jit.trace() internally, which
+    # works well for the current reworked model.
+    warmup = 1.0
+    torch.onnx.export(
+        encoder_model,
+        (x, x_lens, warmup),
+        encoder_filename,
+        verbose=False,
+        opset_version=opset_version,
+        input_names=["x", "x_lens", "warmup"],
+        output_names=["encoder_out", "encoder_out_lens"],
+        dynamic_axes={
+            "x": {0: "N", 1: "T"},
+            "x_lens": {0: "N"},
+            "encoder_out": {0: "N", 1: "T"},
+            "encoder_out_lens": {0: "N"},
+        },
+    )
+    logging.info(f"Saved to {encoder_filename}")
+
+
+def export_decoder_model_onnx(
+    decoder_model: nn.Module,
+    decoder_filename: str,
+    opset_version: int = 11,
+) -> None:
+    """Export the decoder model to ONNX format.
+
+    The exported model has one input:
+
+      - y: a torch.int64 tensor of shape (N, decoder_model.context_size)
+
+    and has one output:
+
+      - decoder_out: a torch.float32 tensor of shape (N, 1, C)
+
+    Note: The argument need_pad is fixed to False.
+
+    Args:
+      decoder_model:
+        The decoder model to be exported.
+      decoder_filename:
+        Filename to save the exported ONNX model.
+      opset_version:
+        The opset version to use.
+    """
+    y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
+    need_pad = False  # Always False, so we can use torch.jit.trace() here
+    # Note(fangjun): torch.jit.trace() is more efficient than
+    # torch.jit.script() in this case
+    torch.onnx.export(
+        decoder_model,
+        (y, need_pad),
+        decoder_filename,
+        verbose=False,
+        opset_version=opset_version,
+        input_names=["y", "need_pad"],
+        output_names=["decoder_out"],
+        dynamic_axes={
+            "y": {0: "N"},
+            "decoder_out": {0: "N"},
+        },
+    )
+    logging.info(f"Saved to {decoder_filename}")
+
+
+def export_joiner_model_onnx(
+    joiner_model: nn.Module,
+    joiner_filename: str,
+    opset_version: int = 11,
+) -> None:
+    """Export the joiner model to ONNX format.
+    The exported joiner model has two inputs:
+
+      - projected_encoder_out: a tensor of shape (N, joiner_dim)
+      - projected_decoder_out: a tensor of shape (N, joiner_dim)
+
+    and produces one output:
+
+      - logit: a tensor of shape (N, vocab_size)
+
+    The exported encoder_proj model has one input:
+
+      - encoder_out: a tensor of shape (N, encoder_out_dim)
+
+    and produces one output:
+
+      - projected_encoder_out: a tensor of shape (N, joiner_dim)
+
+    The exported decoder_proj model has one input:
+
+      - decoder_out: a tensor of shape (N, decoder_out_dim)
+
+    and produces one output:
+
+      - projected_decoder_out: a tensor of shape (N, joiner_dim)
+    """
+    encoder_proj_filename = str(joiner_filename).replace(
+        ".onnx", "_encoder_proj.onnx"
+    )
+
+    decoder_proj_filename = str(joiner_filename).replace(
+        ".onnx", "_decoder_proj.onnx"
+    )
+
+    encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
+    decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
+    joiner_dim = joiner_model.decoder_proj.weight.shape[0]
+
+    projected_encoder_out = torch.rand(1, joiner_dim, dtype=torch.float32)
+    projected_decoder_out = torch.rand(1, joiner_dim, dtype=torch.float32)
+
+    project_input = False
+    # Note: It uses torch.jit.trace() internally
+    torch.onnx.export(
+        joiner_model,
+        (projected_encoder_out, projected_decoder_out, project_input),
+        joiner_filename,
+        verbose=False,
+        opset_version=opset_version,
+        input_names=[
+            "projected_encoder_out",
+            "projected_decoder_out",
+            "project_input",
+        ],
+        output_names=["logit"],
+        dynamic_axes={
+            "projected_encoder_out": {0: "N"},
+            "projected_decoder_out": {0: "N"},
+            "logit": {0: "N"},
+        },
+    )
+    logging.info(f"Saved to {joiner_filename}")
+
+    encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
+    torch.onnx.export(
+        joiner_model.encoder_proj,
+        encoder_out,
+        encoder_proj_filename,
+        verbose=False,
+        opset_version=opset_version,
+        input_names=["encoder_out"],
+        output_names=["projected_encoder_out"],
+        dynamic_axes={
+            "encoder_out": {0: "N"},
+            "projected_encoder_out": {0: "N"},
+        },
+    )
+    logging.info(f"Saved to {encoder_proj_filename}")
+
+    decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
+    torch.onnx.export(
+        joiner_model.decoder_proj,
+        decoder_out,
+        decoder_proj_filename,
+        verbose=False,
+        opset_version=opset_version,
+        input_names=["decoder_out"],
+        output_names=["projected_decoder_out"],
+        dynamic_axes={
+            "decoder_out": {0: "N"},
+            "projected_decoder_out": {0: "N"},
+        },
+    )
+    logging.info(f"Saved to {decoder_proj_filename}")
+
+
 def main():
     args = get_parser().parse_args()
     args.exp_dir = Path(args.exp_dir)
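The dynamic axes above let the exported encoder accept any batch size N and frame count T. A minimal onnxruntime sketch (not part of this commit): the tensor names match the input_names/output_names passed to torch.onnx.export() above, and the float warmup argument appears to be folded into the graph as a constant during tracing, which is consistent with ./onnx_check.py driving the model with only x and x_lens.

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("encoder.onnx")
x = np.random.rand(2, 50, 80).astype(np.float32)  # (N, T, 80)
x_lens = np.array([50, 30], dtype=np.int64)       # valid frames per utterance
encoder_out, encoder_out_lens = session.run(
    ["encoder_out", "encoder_out_lens"],
    {"x": x, "x_lens": x_lens},
)
print(encoder_out.shape, encoder_out_lens)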
@@ -147,22 +575,66 @@ def main():
     model.to(device)
     model.load_state_dict(average_checkpoints(filenames, device=device))

-    model.eval()
-
     model.to("cpu")
     model.eval()

-    if params.jit:
+    if params.onnx is True:
+        convert_scaled_to_non_scaled(model, inplace=True)
+        opset_version = 11
+        logging.info("Exporting to onnx format")
+        encoder_filename = params.exp_dir / "encoder.onnx"
+        export_encoder_model_onnx(
+            model.encoder,
+            encoder_filename,
+            opset_version=opset_version,
+        )
+
+        decoder_filename = params.exp_dir / "decoder.onnx"
+        export_decoder_model_onnx(
+            model.decoder,
+            decoder_filename,
+            opset_version=opset_version,
+        )
+
+        joiner_filename = params.exp_dir / "joiner.onnx"
+        export_joiner_model_onnx(
+            model.joiner,
+            joiner_filename,
+            opset_version=opset_version,
+        )
+    elif params.jit:
+        convert_scaled_to_non_scaled(model, inplace=True)
+        logging.info("Using torch.jit.script")
         # We won't use the forward() method of the model in C++, so just ignore
         # it here.
         # Otherwise, one of its arguments is a ragged tensor and is not
         # torch scriptable.
         model.__class__.forward = torch.jit.ignore(model.__class__.forward)
-        logging.info("Using torch.jit.script")
         model = torch.jit.script(model)
         filename = params.exp_dir / "cpu_jit.pt"
         model.save(str(filename))
         logging.info(f"Saved to {filename}")
+
+        # Also export encoder/decoder/joiner separately
+        encoder_filename = params.exp_dir / "encoder_jit_script.pt"
+        export_encoder_model_jit_script(model.encoder, encoder_filename)
+
+        decoder_filename = params.exp_dir / "decoder_jit_script.pt"
+        export_decoder_model_jit_script(model.decoder, decoder_filename)
+
+        joiner_filename = params.exp_dir / "joiner_jit_script.pt"
+        export_joiner_model_jit_script(model.joiner, joiner_filename)
+    elif params.jit_trace is True:
+        convert_scaled_to_non_scaled(model, inplace=True)
+        logging.info("Using torch.jit.trace()")
+        encoder_filename = params.exp_dir / "encoder_jit_trace.pt"
+        export_encoder_model_jit_trace(model.encoder, encoder_filename)
+
+        decoder_filename = params.exp_dir / "decoder_jit_trace.pt"
+        export_decoder_model_jit_trace(model.decoder, decoder_filename)
+
+        joiner_filename = params.exp_dir / "joiner_jit_trace.pt"
+        export_joiner_model_jit_trace(model.joiner, joiner_filename)
     else:
         logging.info("Not using torch.jit.script")
         # Save it using a format so that it can be loaded
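To make the division of labor among the five exported ONNX files concrete, here is a hedged sketch (not part of this commit) of a single greedy-search step. It assumes the 512-dim encoder/decoder output dimensions that ./onnx_check.py below asserts, and mirrors the project-then-join flow of ./onnx_pretrained.py:

import numpy as np
import onnxruntime as ort

decoder = ort.InferenceSession("decoder.onnx")
joiner = ort.InferenceSession("joiner.onnx")
enc_proj = ort.InferenceSession("joiner_encoder_proj.onnx")
dec_proj = ort.InferenceSession("joiner_decoder_proj.onnx")

context_size = 2
blank_id = 0
hyp = [blank_id] * context_size

# The decoder consumes the last context_size tokens of the hypothesis.
y = np.array([hyp[-context_size:]], dtype=np.int64)      # (1, context_size)
decoder_out = decoder.run(["decoder_out"], {"y": y})[0]  # (1, 1, 512)

# One frame of (unprojected) encoder output; 512-dim per onnx_check.py.
encoder_out_t = np.random.rand(1, 512).astype(np.float32)

proj_enc = enc_proj.run(
    ["projected_encoder_out"], {"encoder_out": encoder_out_t}
)[0]
proj_dec = dec_proj.run(
    ["projected_decoder_out"], {"decoder_out": decoder_out.squeeze(1)}
)[0]

# The joiner was exported with project_input=False, so it expects
# already-projected inputs.
logit = joiner.run(
    ["logit"],
    {"projected_encoder_out": proj_enc, "projected_decoder_out": proj_dec},
)[0]
token = int(logit.argmax(axis=1)[0])  # token == blank_id means: emit nothing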
egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py (new executable file, 339 lines)
@@ -0,0 +1,339 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads torchscript models, either exported by `torch.jit.trace()`
or by `torch.jit.script()`, and uses them to decode waves.
You can use the following command to get the exported models:

./pruned_transducer_stateless2/export.py \
  --exp-dir ./pruned_transducer_stateless2/exp \
  --lang-dir data/lang_char \
  --epoch 10 \
  --avg 2 \
  --jit-trace 1

or

./pruned_transducer_stateless2/export.py \
  --exp-dir ./pruned_transducer_stateless2/exp \
  --lang-dir data/lang_char \
  --epoch 10 \
  --avg 2 \
  --jit 1

Usage of this script:

./pruned_transducer_stateless2/jit_pretrained.py \
  --encoder-model-filename ./pruned_transducer_stateless2/exp/encoder_jit_trace.pt \
  --decoder-model-filename ./pruned_transducer_stateless2/exp/decoder_jit_trace.pt \
  --joiner-model-filename ./pruned_transducer_stateless2/exp/joiner_jit_trace.pt \
  --tokens data/lang_char/tokens.txt \
  /path/to/foo.wav \
  /path/to/bar.wav

or

./pruned_transducer_stateless2/jit_pretrained.py \
  --encoder-model-filename ./pruned_transducer_stateless2/exp/encoder_jit_script.pt \
  --decoder-model-filename ./pruned_transducer_stateless2/exp/decoder_jit_script.pt \
  --joiner-model-filename ./pruned_transducer_stateless2/exp/joiner_jit_script.pt \
  --tokens data/lang_char/tokens.txt \
  /path/to/foo.wav \
  /path/to/bar.wav

You can find pretrained models at
https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2/tree/main/exp
"""

import argparse
import logging
import math
from typing import List

import k2
import kaldifeat
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--encoder-model-filename",
        type=str,
        required=True,
        help="Path to the encoder torchscript model.",
    )

    parser.add_argument(
        "--decoder-model-filename",
        type=str,
        required=True,
        help="Path to the decoder torchscript model.",
    )

    parser.add_argument(
        "--joiner-model-filename",
        type=str,
        required=True,
        help="Path to the joiner torchscript model.",
    )

    parser.add_argument(
        "--tokens",
        type=str,
        help="Path to tokens.txt",
    )

    parser.add_argument(
        "sound_files",
        type=str,
        nargs="+",
        help="The input sound file(s) to transcribe. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported. "
        "The sample rate has to be 16kHz.",
    )

    parser.add_argument(
        "--sample-rate",
        type=int,
        default=16000,
        help="The sample rate of the input sound file",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="Context size of the decoder model",
    )

    return parser


def read_sound_files(
    filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
    """Read a list of sound files into a list of 1-D float32 torch tensors.
    Args:
      filenames:
        A list of sound filenames.
      expected_sample_rate:
        The expected sample rate of the sound files.
    Returns:
      Return a list of 1-D float32 torch tensors.
    """
    ans = []
    for f in filenames:
        wave, sample_rate = torchaudio.load(f)
        assert sample_rate == expected_sample_rate, (
            f"expected sample rate: {expected_sample_rate}. "
            f"Given: {sample_rate}"
        )
        # We use only the first channel
        ans.append(wave[0])
    return ans


def greedy_search(
    decoder: torch.jit.ScriptModule,
    joiner: torch.jit.ScriptModule,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    context_size: int,
) -> List[List[int]]:
    """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
    Args:
      decoder:
        The decoder model.
      joiner:
        The joiner model.
      encoder_out:
        A 3-D tensor of shape (N, T, C)
      encoder_out_lens:
        A 1-D tensor of shape (N,).
      context_size:
        The context size of the decoder model.
    Returns:
      Return the decoded results for each utterance.
    """
    assert encoder_out.ndim == 3
    assert encoder_out.size(0) >= 1, encoder_out.size(0)

    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
        input=encoder_out,
        lengths=encoder_out_lens.cpu(),
        batch_first=True,
        enforce_sorted=False,
    )

    device = encoder_out.device
    blank_id = 0  # hard-code to 0

    batch_size_list = packed_encoder_out.batch_sizes.tolist()
    N = encoder_out.size(0)

    assert torch.all(encoder_out_lens > 0), encoder_out_lens
    assert N == batch_size_list[0], (N, batch_size_list)

    hyps = [[blank_id] * context_size for _ in range(N)]

    decoder_input = torch.tensor(
        hyps,
        device=device,
        dtype=torch.int64,
    )  # (N, context_size)

    decoder_out = decoder(
        decoder_input,
        need_pad=torch.tensor([False]),
    ).squeeze(1)

    offset = 0
    for batch_size in batch_size_list:
        start = offset
        end = offset + batch_size
        current_encoder_out = packed_encoder_out.data[start:end]
        # current_encoder_out's shape: (batch_size, encoder_out_dim)
        offset = end

        decoder_out = decoder_out[:batch_size]

        logits = joiner(
            current_encoder_out,
            decoder_out,
        )
        # logits' shape: (batch_size, vocab_size)

        assert logits.ndim == 2, logits.shape
        y = logits.argmax(dim=1).tolist()
        emitted = False
        for i, v in enumerate(y):
            if v != blank_id:
                hyps[i].append(v)
                emitted = True
        if emitted:
            # update decoder output
            decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
            decoder_input = torch.tensor(
                decoder_input,
                device=device,
                dtype=torch.int64,
            )
            decoder_out = decoder(
                decoder_input,
                need_pad=torch.tensor([False]),
            )
            decoder_out = decoder_out.squeeze(1)

    sorted_ans = [h[context_size:] for h in hyps]
    ans = []
    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
    for i in range(N):
        ans.append(sorted_ans[unsorted_indices[i]])

    return ans
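The loop over batch_size_list in greedy_search() above relies on pack_padded_sequence laying frames out time-major. A tiny illustration (not part of this commit):

import torch

x = torch.arange(12, dtype=torch.float32).reshape(3, 4, 1)  # (N=3, T=4, C=1)
lens = torch.tensor([4, 2, 1])
packed = torch.nn.utils.rnn.pack_padded_sequence(
    x, lens, batch_first=True, enforce_sorted=False
)
# batch_sizes[t] is the number of utterances still active at frame t, so
# slicing packed.data in consecutive chunks walks the batch frame by frame.
print(packed.batch_sizes)  # tensor([3, 2, 1, 1])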
@torch.no_grad()
def main():
    parser = get_parser()
    args = parser.parse_args()
    logging.info(vars(args))

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    encoder = torch.jit.load(args.encoder_model_filename)
    decoder = torch.jit.load(args.decoder_model_filename)
    joiner = torch.jit.load(args.joiner_model_filename)

    encoder.eval()
    decoder.eval()
    joiner.eval()

    encoder.to(device)
    decoder.to(device)
    joiner.to(device)

    logging.info("Constructing Fbank computer")
    opts = kaldifeat.FbankOptions()
    opts.device = device
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = args.sample_rate
    opts.mel_opts.num_bins = 80

    fbank = kaldifeat.Fbank(opts)

    logging.info(f"Reading sound files: {args.sound_files}")
    waves = read_sound_files(
        filenames=args.sound_files,
        expected_sample_rate=args.sample_rate,
    )
    waves = [w.to(device) for w in waves]

    logging.info("Decoding started")
    features = fbank(waves)
    feature_lengths = [f.size(0) for f in features]

    features = pad_sequence(
        features,
        batch_first=True,
        padding_value=math.log(1e-10),
    )

    feature_lengths = torch.tensor(feature_lengths, device=device)

    encoder_out, encoder_out_lens = encoder(
        x=features,
        x_lens=feature_lengths,
    )

    hyps = greedy_search(
        decoder=decoder,
        joiner=joiner,
        encoder_out=encoder_out,
        encoder_out_lens=encoder_out_lens,
        context_size=args.context_size,
    )
    symbol_table = k2.SymbolTable.from_file(args.tokens)
    s = "\n"
    for filename, hyp in zip(args.sound_files, hyps):
        words = "".join([symbol_table[i] for i in hyp])
        s += f"{filename}:\n{words}\n\n"
    logging.info(s)

    logging.info("Decoding Done")


if __name__ == "__main__":
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py (new executable file, 307 lines)
@@ -0,0 +1,307 @@
#!/usr/bin/env python3
#
# Copyright 2022 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script checks that the exported ONNX models produce the same output
as the given torchscript model for the same input.

Usage:

./pruned_transducer_stateless2/onnx_check.py \
  --jit-filename ./t/cpu_jit.pt \
  --onnx-encoder-filename ./t/encoder.onnx \
  --onnx-decoder-filename ./t/decoder.onnx \
  --onnx-joiner-filename ./t/joiner.onnx \
  --onnx-joiner-encoder-proj-filename ./t/joiner_encoder_proj.onnx \
  --onnx-joiner-decoder-proj-filename ./t/joiner_decoder_proj.onnx

You can generate cpu_jit.pt, encoder.onnx, decoder.onnx, and the other
xxx.onnx files using ./export.py

We provide pretrained models at:
https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2/tree/main/exp
"""

import argparse
import logging

import onnxruntime as ort
import torch

ort.set_default_logger_severity(3)


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--jit-filename",
        required=True,
        type=str,
        help="Path to the torchscript model exported by torch.jit.script",
    )

    parser.add_argument(
        "--onnx-encoder-filename",
        required=True,
        type=str,
        help="Path to the onnx encoder model",
    )

    parser.add_argument(
        "--onnx-decoder-filename",
        required=True,
        type=str,
        help="Path to the onnx decoder model",
    )

    parser.add_argument(
        "--onnx-joiner-filename",
        required=True,
        type=str,
        help="Path to the onnx joiner model",
    )

    parser.add_argument(
        "--onnx-joiner-encoder-proj-filename",
        required=True,
        type=str,
        help="Path to the onnx joiner encoder projection model",
    )

    parser.add_argument(
        "--onnx-joiner-decoder-proj-filename",
        required=True,
        type=str,
        help="Path to the onnx joiner decoder projection model",
    )

    return parser


def test_encoder(
    model: torch.jit.ScriptModule,
    encoder_session: ort.InferenceSession,
):
    inputs = encoder_session.get_inputs()
    outputs = encoder_session.get_outputs()
    input_names = [n.name for n in inputs]
    output_names = [n.name for n in outputs]

    assert inputs[0].shape == ["N", "T", 80]
    assert inputs[1].shape == ["N"]

    for N in [1, 5]:
        for T in [12, 25]:
            print("N, T", N, T)
            x = torch.rand(N, T, 80, dtype=torch.float32)
            x_lens = torch.randint(low=10, high=T + 1, size=(N,))
            x_lens[0] = T

            encoder_inputs = {
                input_names[0]: x.numpy(),
                input_names[1]: x_lens.numpy(),
            }
            encoder_out, encoder_out_lens = encoder_session.run(
                output_names,
                encoder_inputs,
            )

            torch_encoder_out, torch_encoder_out_lens = model.encoder(x, x_lens)

            encoder_out = torch.from_numpy(encoder_out)
            assert torch.allclose(encoder_out, torch_encoder_out, atol=1e-05), (
                (encoder_out - torch_encoder_out).abs().max(),
                encoder_out.shape,
                torch_encoder_out.shape,
            )


def test_decoder(
    model: torch.jit.ScriptModule,
    decoder_session: ort.InferenceSession,
):
    inputs = decoder_session.get_inputs()
    outputs = decoder_session.get_outputs()
    input_names = [n.name for n in inputs]
    output_names = [n.name for n in outputs]

    assert inputs[0].shape == ["N", 2]
    for N in [1, 5, 10]:
        y = torch.randint(low=1, high=500, size=(N, 2))

        decoder_inputs = {input_names[0]: y.numpy()}
        decoder_out = decoder_session.run(
            output_names,
            decoder_inputs,
        )[0]
        decoder_out = torch.from_numpy(decoder_out)

        torch_decoder_out = model.decoder(y, need_pad=False)
        assert torch.allclose(decoder_out, torch_decoder_out, atol=1e-5), (
            (decoder_out - torch_decoder_out).abs().max()
        )


def test_joiner(
    model: torch.jit.ScriptModule,
    joiner_session: ort.InferenceSession,
    joiner_encoder_proj_session: ort.InferenceSession,
    joiner_decoder_proj_session: ort.InferenceSession,
):
    joiner_inputs = joiner_session.get_inputs()
    joiner_outputs = joiner_session.get_outputs()
    joiner_input_names = [n.name for n in joiner_inputs]
    joiner_output_names = [n.name for n in joiner_outputs]

    assert joiner_inputs[0].shape == ["N", 512]
    assert joiner_inputs[1].shape == ["N", 512]

    joiner_encoder_proj_inputs = joiner_encoder_proj_session.get_inputs()
    encoder_proj_input_name = joiner_encoder_proj_inputs[0].name

    assert joiner_encoder_proj_inputs[0].shape == ["N", 512]

    joiner_encoder_proj_outputs = joiner_encoder_proj_session.get_outputs()
    encoder_proj_output_name = joiner_encoder_proj_outputs[0].name

    joiner_decoder_proj_inputs = joiner_decoder_proj_session.get_inputs()
    decoder_proj_input_name = joiner_decoder_proj_inputs[0].name

    assert joiner_decoder_proj_inputs[0].shape == ["N", 512]

    joiner_decoder_proj_outputs = joiner_decoder_proj_session.get_outputs()
    decoder_proj_output_name = joiner_decoder_proj_outputs[0].name

    for N in [1, 5, 10]:
        encoder_out = torch.rand(N, 512)
        decoder_out = torch.rand(N, 512)

        projected_encoder_out = torch.rand(N, 512)
        projected_decoder_out = torch.rand(N, 512)

        joiner_inputs = {
            joiner_input_names[0]: projected_encoder_out.numpy(),
            joiner_input_names[1]: projected_decoder_out.numpy(),
        }
        joiner_out = joiner_session.run(joiner_output_names, joiner_inputs)[0]
        joiner_out = torch.from_numpy(joiner_out)

        torch_joiner_out = model.joiner(
            projected_encoder_out,
            projected_decoder_out,
            project_input=False,
        )
        assert torch.allclose(joiner_out, torch_joiner_out, atol=1e-5), (
            (joiner_out - torch_joiner_out).abs().max()
        )

        # Now test encoder_proj
        joiner_encoder_proj_inputs = {
            encoder_proj_input_name: encoder_out.numpy()
        }
        joiner_encoder_proj_out = joiner_encoder_proj_session.run(
            [encoder_proj_output_name], joiner_encoder_proj_inputs
        )[0]
        joiner_encoder_proj_out = torch.from_numpy(joiner_encoder_proj_out)

        torch_joiner_encoder_proj_out = model.joiner.encoder_proj(encoder_out)
        assert torch.allclose(
            joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5
        ), (
            (joiner_encoder_proj_out - torch_joiner_encoder_proj_out)
            .abs()
            .max()
        )

        # Now test decoder_proj
        joiner_decoder_proj_inputs = {
            decoder_proj_input_name: decoder_out.numpy()
        }
        joiner_decoder_proj_out = joiner_decoder_proj_session.run(
            [decoder_proj_output_name], joiner_decoder_proj_inputs
        )[0]
        joiner_decoder_proj_out = torch.from_numpy(joiner_decoder_proj_out)

        torch_joiner_decoder_proj_out = model.joiner.decoder_proj(decoder_out)
        assert torch.allclose(
            joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5
        ), (
            (joiner_decoder_proj_out - torch_joiner_decoder_proj_out)
            .abs()
            .max()
        )


@torch.no_grad()
def main():
    args = get_parser().parse_args()
    logging.info(vars(args))

    model = torch.jit.load(args.jit_filename)

    options = ort.SessionOptions()
    options.inter_op_num_threads = 1
    options.intra_op_num_threads = 1

    logging.info("Test encoder")
    encoder_session = ort.InferenceSession(
        args.onnx_encoder_filename,
        sess_options=options,
    )
    test_encoder(model, encoder_session)

    logging.info("Test decoder")
    decoder_session = ort.InferenceSession(
        args.onnx_decoder_filename,
        sess_options=options,
    )
    test_decoder(model, decoder_session)

    logging.info("Test joiner")
    joiner_session = ort.InferenceSession(
        args.onnx_joiner_filename,
        sess_options=options,
    )
    joiner_encoder_proj_session = ort.InferenceSession(
        args.onnx_joiner_encoder_proj_filename,
        sess_options=options,
    )
    joiner_decoder_proj_session = ort.InferenceSession(
        args.onnx_joiner_decoder_proj_filename,
        sess_options=options,
    )
    test_joiner(
        model,
        joiner_session,
        joiner_encoder_proj_session,
        joiner_decoder_proj_session,
    )
    logging.info("Finished checking ONNX models")


if __name__ == "__main__":
    torch.manual_seed(20220727)
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
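The shape assertions in the tests above come straight from the exported graphs' metadata. A small sketch (not part of this commit) for inspecting any of the generated files:

import onnxruntime as ort

session = ort.InferenceSession("joiner.onnx")
for i in session.get_inputs():
    print("input:", i.name, i.shape, i.type)
for o in session.get_outputs():
    print("output:", o.name, o.shape, o.type)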
egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py (new executable file, 391 lines)
@ -0,0 +1,391 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads ONNX models and uses them to decode waves.
You can use the following command to get the exported models:

./pruned_transducer_stateless2/export.py \
  --exp-dir ./pruned_transducer_stateless2/exp \
  --lang-dir data/lang_char \
  --epoch 20 \
  --avg 10 \
  --onnx 1

Usage of this script:

./pruned_transducer_stateless2/onnx_pretrained.py \
  --encoder-model-filename ./pruned_transducer_stateless2/exp/encoder.onnx \
  --decoder-model-filename ./pruned_transducer_stateless2/exp/decoder.onnx \
  --joiner-model-filename ./pruned_transducer_stateless2/exp/joiner.onnx \
  --joiner-encoder-proj-model-filename ./pruned_transducer_stateless2/exp/joiner_encoder_proj.onnx \
  --joiner-decoder-proj-model-filename ./pruned_transducer_stateless2/exp/joiner_decoder_proj.onnx \
  --tokens data/lang_char/tokens.txt \
  /path/to/foo.wav \
  /path/to/bar.wav

We provide pretrained models at:
https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2/tree/main/exp
"""

import argparse
import logging
import math
from typing import List

import k2
import kaldifeat
import numpy as np
import onnxruntime as ort
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--encoder-model-filename",
        type=str,
        required=True,
        help="Path to the encoder onnx model. ",
    )

    parser.add_argument(
        "--decoder-model-filename",
        type=str,
        required=True,
        help="Path to the decoder onnx model. ",
    )

    parser.add_argument(
        "--joiner-model-filename",
        type=str,
        required=True,
        help="Path to the joiner onnx model. ",
    )

    parser.add_argument(
        "--joiner-encoder-proj-model-filename",
        type=str,
        required=True,
        help="Path to the joiner encoder_proj onnx model. ",
    )

    parser.add_argument(
        "--joiner-decoder-proj-model-filename",
        type=str,
        required=True,
        help="Path to the joiner decoder_proj onnx model. ",
    )

    parser.add_argument(
        "--tokens",
        type=str,
        help="""Path to tokens.txt""",
    )

    parser.add_argument(
        "sound_files",
        type=str,
        nargs="+",
        help="The input sound file(s) to transcribe. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported. "
        "The sample rate has to be 16kHz.",
    )

    parser.add_argument(
        "--sample-rate",
        type=int,
        default=16000,
        help="The sample rate of the input sound file",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="Context size of the decoder model",
    )

    return parser


def read_sound_files(
    filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
    """Read a list of sound files into a list of 1-D float32 torch tensors.
    Args:
      filenames:
        A list of sound filenames.
      expected_sample_rate:
        The expected sample rate of the sound files.
    Returns:
      Return a list of 1-D float32 torch tensors.
    """
    ans = []
    for f in filenames:
        wave, sample_rate = torchaudio.load(f)
        assert sample_rate == expected_sample_rate, (
            f"expected sample rate: {expected_sample_rate}. "
            f"Given: {sample_rate}"
        )
        # We use only the first channel
        ans.append(wave[0])
    return ans


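# A minimal usage sketch (hypothetical file names, assuming 16 kHz mono wavs):
#
#   waves = read_sound_files(["foo.wav", "bar.wav"], expected_sample_rate=16000)
#   assert all(w.ndim == 1 and w.dtype == torch.float32 for w in waves)

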
def greedy_search(
    decoder: ort.InferenceSession,
    joiner: ort.InferenceSession,
    joiner_encoder_proj: ort.InferenceSession,
    joiner_decoder_proj: ort.InferenceSession,
    encoder_out: np.ndarray,
    encoder_out_lens: np.ndarray,
    context_size: int,
) -> List[List[int]]:
    """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
    Args:
      decoder:
        The decoder model.
      joiner:
        The joiner model.
      joiner_encoder_proj:
        The joiner encoder projection model.
      joiner_decoder_proj:
        The joiner decoder projection model.
      encoder_out:
        A 3-D tensor of shape (N, T, C)
      encoder_out_lens:
        A 1-D tensor of shape (N,).
      context_size:
        The context size of the decoder model.
    Returns:
      Return the decoded results for each utterance.
    """
    encoder_out = torch.from_numpy(encoder_out)
    encoder_out_lens = torch.from_numpy(encoder_out_lens)
    assert encoder_out.ndim == 3
    assert encoder_out.size(0) >= 1, encoder_out.size(0)

    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
        input=encoder_out,
        lengths=encoder_out_lens.cpu(),
        batch_first=True,
        enforce_sorted=False,
    )

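    # How the packed layout drives the batched search below (a worked
    # example, not from the original code): for encoder_out_lens == [3, 2],
    # packed_encoder_out.batch_sizes == [2, 2, 1]. At time step t we use
    # the first batch_sizes[t] rows of the packed data, so an utterance
    # simply drops out of the batch once its frames are exhausted.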
    projected_encoder_out = joiner_encoder_proj.run(
        [joiner_encoder_proj.get_outputs()[0].name],
        {
            joiner_encoder_proj.get_inputs()[0].name: packed_encoder_out.data.numpy()
        },
    )[0]

    blank_id = 0  # hard-code to 0

    batch_size_list = packed_encoder_out.batch_sizes.tolist()
    N = encoder_out.size(0)

    assert torch.all(encoder_out_lens > 0), encoder_out_lens
    assert N == batch_size_list[0], (N, batch_size_list)

    hyps = [[blank_id] * context_size for _ in range(N)]
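    # With context_size == 2, each hypothesis starts as [0, 0]: two priming
    # blanks that feed the stateless decoder; they are stripped at the end
    # via h[context_size:].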

    decoder_input_nodes = decoder.get_inputs()
    decoder_output_nodes = decoder.get_outputs()

    joiner_input_nodes = joiner.get_inputs()
    joiner_output_nodes = joiner.get_outputs()

    decoder_input = torch.tensor(
        hyps,
        dtype=torch.int64,
    )  # (N, context_size)

    decoder_out = decoder.run(
        [decoder_output_nodes[0].name],
        {
            decoder_input_nodes[0].name: decoder_input.numpy(),
        },
    )[0].squeeze(1)
    projected_decoder_out = joiner_decoder_proj.run(
        [joiner_decoder_proj.get_outputs()[0].name],
        {joiner_decoder_proj.get_inputs()[0].name: decoder_out},
    )[0]

    projected_decoder_out = torch.from_numpy(projected_decoder_out)

    offset = 0
    for batch_size in batch_size_list:
        start = offset
        end = offset + batch_size
        current_encoder_out = projected_encoder_out[start:end]
        # current_encoder_out's shape: (batch_size, encoder_out_dim)
        offset = end

        projected_decoder_out = projected_decoder_out[:batch_size]

        logits = joiner.run(
            [joiner_output_nodes[0].name],
            {
                joiner_input_nodes[0].name: current_encoder_out,
                joiner_input_nodes[1].name: projected_decoder_out.numpy(),
            },
        )[0]
        logits = torch.from_numpy(logits).squeeze(1).squeeze(1)
        # logits' shape: (batch_size, vocab_size)

        assert logits.ndim == 2, logits.shape
        y = logits.argmax(dim=1).tolist()
        emitted = False
        for i, v in enumerate(y):
            if v != blank_id:
                hyps[i].append(v)
                emitted = True
        if emitted:
            # update decoder output
            decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
            decoder_input = torch.tensor(
                decoder_input,
                dtype=torch.int64,
            )
            decoder_out = decoder.run(
                [decoder_output_nodes[0].name],
                {
                    decoder_input_nodes[0].name: decoder_input.numpy(),
                },
            )[0].squeeze(1)
            projected_decoder_out = joiner_decoder_proj.run(
                [joiner_decoder_proj.get_outputs()[0].name],
                {joiner_decoder_proj.get_inputs()[0].name: decoder_out},
            )[0]
            projected_decoder_out = torch.from_numpy(projected_decoder_out)

    sorted_ans = [h[context_size:] for h in hyps]
    ans = []
    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
    for i in range(N):
        ans.append(sorted_ans[unsorted_indices[i]])

    return ans


@torch.no_grad()
def main():
    parser = get_parser()
    args = parser.parse_args()
    logging.info(vars(args))

    session_opts = ort.SessionOptions()
    session_opts.inter_op_num_threads = 1
    session_opts.intra_op_num_threads = 1

    encoder = ort.InferenceSession(
        args.encoder_model_filename,
        sess_options=session_opts,
    )

    decoder = ort.InferenceSession(
        args.decoder_model_filename,
        sess_options=session_opts,
    )

    joiner = ort.InferenceSession(
        args.joiner_model_filename,
        sess_options=session_opts,
    )

    joiner_encoder_proj = ort.InferenceSession(
        args.joiner_encoder_proj_model_filename,
        sess_options=session_opts,
    )

    joiner_decoder_proj = ort.InferenceSession(
        args.joiner_decoder_proj_model_filename,
        sess_options=session_opts,
    )

    logging.info("Constructing Fbank computer")
    opts = kaldifeat.FbankOptions()
    opts.device = "cpu"
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = args.sample_rate
    opts.mel_opts.num_bins = 80

    fbank = kaldifeat.Fbank(opts)

    logging.info(f"Reading sound files: {args.sound_files}")
    waves = read_sound_files(
        filenames=args.sound_files,
        expected_sample_rate=args.sample_rate,
    )

    logging.info("Decoding started")
    features = fbank(waves)
    feature_lengths = [f.size(0) for f in features]

    features = pad_sequence(
        features,
        batch_first=True,
        padding_value=math.log(1e-10),
    )
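
    # Padding with math.log(1e-10) makes the padded frames look like
    # near-silence in the log-mel domain; padding with 0 would instead
    # correspond to a mel energy of 1.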

    feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64)

    encoder_input_nodes = encoder.get_inputs()
    encoder_out_nodes = encoder.get_outputs()
    encoder_out, encoder_out_lens = encoder.run(
        [encoder_out_nodes[0].name, encoder_out_nodes[1].name],
        {
            encoder_input_nodes[0].name: features.numpy(),
            encoder_input_nodes[1].name: feature_lengths.numpy(),
        },
    )
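
    # Per the greedy_search() docstring: encoder_out is (N, T, C) and
    # encoder_out_lens is (N,); both are numpy arrays at this point.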

    hyps = greedy_search(
        decoder=decoder,
        joiner=joiner,
        joiner_encoder_proj=joiner_encoder_proj,
        joiner_decoder_proj=joiner_decoder_proj,
        encoder_out=encoder_out,
        encoder_out_lens=encoder_out_lens,
        context_size=args.context_size,
    )
    symbol_table = k2.SymbolTable.from_file(args.tokens)
    s = "\n"
    for filename, hyp in zip(args.sound_files, hyps):
        words = "".join([symbol_table[i] for i in hyp])
        s += f"{filename}:\n{words}\n\n"
    logging.info(s)

    logging.info("Decoding Done")


if __name__ == "__main__":
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
13
egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py
Normal file → Executable file
@ -21,7 +21,7 @@ Usage:
 ./pruned_transducer_stateless2/pretrained.py \
   --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
   --lang-dir ./data/lang_char \
-  --method greedy_search \
+  --decoding-method greedy_search \
   --max-sym-per-frame 1 \
   /path/to/foo.wav \
   /path/to/bar.wav
@ -29,7 +29,7 @@ Usage:
 ./pruned_transducer_stateless2/pretrained.py \
   --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
   --lang-dir ./data/lang_char \
-  --method modified_beam_search \
+  --decoding-method modified_beam_search \
   --beam-size 4 \
   /path/to/foo.wav \
   /path/to/bar.wav
@ -37,7 +37,7 @@ Usage:
 ./pruned_transducer_stateless2/pretrained.py \
   --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
   --lang-dir ./data/lang_char \
-  --method fast_beam_search \
+  --decoding-method fast_beam_search \
   --beam 4 \
   --max-contexts 4 \
   --max-states 8 \
@ -116,7 +116,7 @@ def get_parser():
     parser.add_argument(
         "--sample-rate",
         type=int,
-        default=48000,
+        default=16000,
         help="The sample rate of the input sound file",
     )
 
@ -124,7 +124,8 @@ def get_parser():
         "--beam-size",
         type=int,
         default=4,
-        help="Used only when --method is beam_search and modified_beam_search ",
+        help="""Used only when --decoding-method is beam_search
+        and modified_beam_search """,
     )
 
     parser.add_argument(
@ -166,7 +167,7 @@ def get_parser():
         type=int,
         default=1,
         help="""Maximum number of symbols per frame. Used only when
-        --method is greedy_search.
+        --decoding-method is greedy_search.
         """,
     )
 
1
egs/wenetspeech/ASR/pruned_transducer_stateless2/scaling_converter.py
@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py