mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-09 22:15:28 +00:00
Support exporting to ONNX for the wenetspeech recipe (#615)
* Support exporting to ONNX for the wenetspeech recipe
This commit is contained in:
parent
aa58c2ee02
commit
c39cba5191
124
.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh
vendored
Executable file
124
.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh
vendored
Executable file
@ -0,0 +1,124 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -e
|
||||
|
||||
log() {
|
||||
# This function is from espnet
|
||||
local fname=${BASH_SOURCE[1]##*/}
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
cd egs/wenetspeech/ASR
|
||||
|
||||
repo_url=https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2
|
||||
|
||||
log "Downloading pre-trained model from $repo_url"
|
||||
git lfs install
|
||||
git clone $repo_url
|
||||
repo=$(basename $repo_url)
|
||||
|
||||
|
||||
log "Display test files"
|
||||
tree $repo/
|
||||
soxi $repo/test_wavs/*.wav
|
||||
ls -lh $repo/test_wavs/*.wav
|
||||
|
||||
pushd $repo/exp
|
||||
ln -s pretrained_epoch_10_avg_2.pt pretrained.pt
|
||||
ln -s pretrained_epoch_10_avg_2.pt epoch-99.pt
|
||||
popd
|
||||
|
||||
log "Test exporting to ONNX format"
|
||||
|
||||
./pruned_transducer_stateless2/export.py \
|
||||
--exp-dir $repo/exp \
|
||||
--lang-dir $repo/data/lang_char \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--onnx 1
|
||||
|
||||
log "Export to torchscript model"
|
||||
|
||||
./pruned_transducer_stateless2/export.py \
|
||||
--exp-dir $repo/exp \
|
||||
--lang-dir $repo/data/lang_char \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--jit 1
|
||||
|
||||
./pruned_transducer_stateless2/export.py \
|
||||
--exp-dir $repo/exp \
|
||||
--lang-dir $repo/data/lang_char \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--jit-trace 1
|
||||
|
||||
ls -lh $repo/exp/*.onnx
|
||||
ls -lh $repo/exp/*.pt
|
||||
|
||||
log "Decode with ONNX models"
|
||||
|
||||
./pruned_transducer_stateless2/onnx_check.py \
|
||||
--jit-filename $repo/exp/cpu_jit.pt \
|
||||
--onnx-encoder-filename $repo/exp/encoder.onnx \
|
||||
--onnx-decoder-filename $repo/exp/decoder.onnx \
|
||||
--onnx-joiner-filename $repo/exp/joiner.onnx \
|
||||
--onnx-joiner-encoder-proj-filename $repo/exp/joiner_encoder_proj.onnx \
|
||||
--onnx-joiner-decoder-proj-filename $repo/exp/joiner_decoder_proj.onnx
|
||||
|
||||
./pruned_transducer_stateless2/onnx_pretrained.py \
|
||||
--tokens $repo/data/lang_char/tokens.txt \
|
||||
--encoder-model-filename $repo/exp/encoder.onnx \
|
||||
--decoder-model-filename $repo/exp/decoder.onnx \
|
||||
--joiner-model-filename $repo/exp/joiner.onnx \
|
||||
--joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \
|
||||
--joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \
|
||||
$repo/test_wavs/DEV_T0000000000.wav \
|
||||
$repo/test_wavs/DEV_T0000000001.wav \
|
||||
$repo/test_wavs/DEV_T0000000002.wav
|
||||
|
||||
log "Decode with models exported by torch.jit.trace()"
|
||||
|
||||
./pruned_transducer_stateless2/jit_pretrained.py \
|
||||
--tokens $repo/data/lang_char/tokens.txt \
|
||||
--encoder-model-filename $repo/exp/encoder_jit_trace.pt \
|
||||
--decoder-model-filename $repo/exp/decoder_jit_trace.pt \
|
||||
--joiner-model-filename $repo/exp/joiner_jit_trace.pt \
|
||||
$repo/test_wavs/DEV_T0000000000.wav \
|
||||
$repo/test_wavs/DEV_T0000000001.wav \
|
||||
$repo/test_wavs/DEV_T0000000002.wav
|
||||
|
||||
./pruned_transducer_stateless2/jit_pretrained.py \
|
||||
--tokens $repo/data/lang_char/tokens.txt \
|
||||
--encoder-model-filename $repo/exp/encoder_jit_script.pt \
|
||||
--decoder-model-filename $repo/exp/decoder_jit_script.pt \
|
||||
--joiner-model-filename $repo/exp/joiner_jit_script.pt \
|
||||
$repo/test_wavs/DEV_T0000000000.wav \
|
||||
$repo/test_wavs/DEV_T0000000001.wav \
|
||||
$repo/test_wavs/DEV_T0000000002.wav
|
||||
|
||||
for sym in 1 2 3; do
|
||||
log "Greedy search with --max-sym-per-frame $sym"
|
||||
|
||||
./pruned_transducer_stateless2/pretrained.py \
|
||||
--checkpoint $repo/exp/epoch-99.pt \
|
||||
--lang-dir $repo/data/lang_char \
|
||||
--decoding-method greedy_search \
|
||||
--max-sym-per-frame $sym \
|
||||
$repo/test_wavs/DEV_T0000000000.wav \
|
||||
$repo/test_wavs/DEV_T0000000001.wav \
|
||||
$repo/test_wavs/DEV_T0000000002.wav
|
||||
done
|
||||
|
||||
for method in modified_beam_search beam_search fast_beam_search; do
|
||||
log "$method"
|
||||
|
||||
./pruned_transducer_stateless2/pretrained.py \
|
||||
--decoding-method $method \
|
||||
--beam-size 4 \
|
||||
--checkpoint $repo/exp/epoch-99.pt \
|
||||
--lang-dir $repo/data/lang_char \
|
||||
$repo/test_wavs/DEV_T0000000000.wav \
|
||||
$repo/test_wavs/DEV_T0000000001.wav \
|
||||
$repo/test_wavs/DEV_T0000000002.wav
|
||||
done
|
||||
80
.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml
vendored
Normal file
80
.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml
vendored
Normal file
@ -0,0 +1,80 @@
|
||||
# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
|
||||
|
||||
# See ../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
name: run-wenetspeech-pruned-transducer-stateless2
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
pull_request:
|
||||
types: [labeled]
|
||||
|
||||
jobs:
|
||||
run_librispeech_pruned_transducer_stateless3_2022_05_13:
|
||||
if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'wenetspeech'
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-18.04]
|
||||
python-version: [3.8]
|
||||
|
||||
fail-fast: false
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache: 'pip'
|
||||
cache-dependency-path: '**/requirements-ci.txt'
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: |
|
||||
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
|
||||
pip uninstall -y protobuf
|
||||
pip install --no-binary protobuf protobuf
|
||||
|
||||
- name: Cache kaldifeat
|
||||
id: my-cache
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
path: |
|
||||
~/tmp/kaldifeat
|
||||
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
|
||||
|
||||
- name: Install kaldifeat
|
||||
if: steps.my-cache.outputs.cache-hit != 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
.github/scripts/install-kaldifeat.sh
|
||||
|
||||
- name: Inference with pre-trained model
|
||||
shell: bash
|
||||
env:
|
||||
GITHUB_EVENT_NAME: ${{ github.event_name }}
|
||||
GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
|
||||
run: |
|
||||
sudo apt-get -qq install git-lfs tree sox
|
||||
export PYTHONPATH=$PWD:$PYTHONPATH
|
||||
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
|
||||
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
|
||||
|
||||
.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh
|
||||
@ -73,6 +73,10 @@ Check `onnx_check.py` for how to use them.
|
||||
|
||||
Please see ./onnx_pretrained.py for usage of the generated files
|
||||
|
||||
Check
|
||||
https://github.com/k2-fsa/sherpa-onnx
|
||||
for how to use the exported models outside of icefall.
|
||||
|
||||
(4) Export `model.state_dict()`
|
||||
|
||||
./pruned_transducer_stateless3/export.py \
|
||||
|
||||
@ -31,6 +31,8 @@ Usage of this script:
|
||||
--encoder-model-filename ./pruned_transducer_stateless3/exp/encoder.onnx \
|
||||
--decoder-model-filename ./pruned_transducer_stateless3/exp/decoder.onnx \
|
||||
--joiner-model-filename ./pruned_transducer_stateless3/exp/joiner.onnx \
|
||||
--joiner-encoder-proj-model-filename ./pruned_transducer_stateless3/exp/joiner_encoder_proj.onnx \
|
||||
--joiner-decoder-proj-model-filename ./pruned_transducer_stateless3/exp/joiner_decoder_proj.onnx \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
484
egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py
Normal file → Executable file
484
egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py
Normal file → Executable file
@ -1,3 +1,4 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
@ -18,6 +19,64 @@
|
||||
# to a single one using model averaging.
|
||||
"""
|
||||
Usage:
|
||||
|
||||
(1) Export to torchscript model using torch.jit.script()
|
||||
|
||||
./pruned_transducer_stateless2/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--lang-dir data/lang_char \
|
||||
--epoch 10 \
|
||||
--avg 2 \
|
||||
--jit 1
|
||||
|
||||
It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
|
||||
load it by `torch.jit.load("cpu_jit.pt")`.
|
||||
|
||||
Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python
|
||||
are on CPU. You can use `to("cuda")` to move them to a CUDA device.
|
||||
|
||||
Please refer to
|
||||
https://k2-fsa.github.io/sherpa/python/offline_asr/conformer/index.html
|
||||
for how to use `cpu_jit.pt` for speech recognition.
|
||||
|
||||
It will also generate 3 other files: `encoder_jit_script.pt`,
|
||||
`decoder_jit_script.pt`, and `joiner_jit_script.pt`. Check ./jit_pretrained.py
|
||||
for how to use them.
|
||||
|
||||
(2) Export to torchscript model using torch.jit.trace()
|
||||
|
||||
./pruned_transducer_stateless2/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--lang-dir data/lang_char \
|
||||
--epoch 10 \
|
||||
--avg 2 \
|
||||
--jit-trace 1
|
||||
|
||||
It will generate the following files:
|
||||
- encoder_jit_trace.pt
|
||||
- decoder_jit_trace.pt
|
||||
- joiner_jit_trace.pt
|
||||
|
||||
Check ./jit_pretrained.py for usage.
|
||||
|
||||
(3) Export to ONNX format
|
||||
|
||||
./pruned_transducer_stateless2/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--lang-dir data/lang_char \
|
||||
--epoch 10 \
|
||||
--avg 2 \
|
||||
--onnx 1
|
||||
|
||||
Refer to ./onnx_check.py and ./onnx_pretrained.py
|
||||
for usage.
|
||||
|
||||
Check
|
||||
https://github.com/k2-fsa/sherpa-onnx
|
||||
for how to use the exported models outside of icefall.
|
||||
|
||||
(4) Export `model.state_dict()`
|
||||
|
||||
./pruned_transducer_stateless2/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--lang-dir data/lang_char \
|
||||
@ -35,10 +94,13 @@ you can do:
|
||||
cd /path/to/egs/wenetspeech/ASR
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--epoch 10 \
|
||||
--avg 2 \
|
||||
--epoch 9999 \
|
||||
--avg 1 \
|
||||
--max-duration 100 \
|
||||
--lang-dir data/lang_char
|
||||
|
||||
You can find pretrained models at
|
||||
https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2/tree/main/exp
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@ -46,6 +108,8 @@ import logging
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
@ -96,6 +160,44 @@ def get_parser():
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""True to save a model after applying torch.jit.script.
|
||||
It will generate 4 files:
|
||||
- encoder_jit_script.pt
|
||||
- decoder_jit_script.pt
|
||||
- joiner_jit_script.pt
|
||||
- cpu_jit.pt (which combines the above 3 files)
|
||||
|
||||
Check ./jit_pretrained.py for how to use xxx_jit_script.pt
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--jit-trace",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""True to save a model after applying torch.jit.trace.
|
||||
It will generate 3 files:
|
||||
- encoder_jit_trace.pt
|
||||
- decoder_jit_trace.pt
|
||||
- joiner_jit_trace.pt
|
||||
|
||||
Check ./jit_pretrained.py for how to use them.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--onnx",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""If True, --jit is ignored and it exports the model
|
||||
to onnx format. It will generate the following files:
|
||||
|
||||
- encoder.onnx
|
||||
- decoder.onnx
|
||||
- joiner.onnx
|
||||
- joiner_encoder_proj.onnx
|
||||
- joiner_decoder_proj.onnx
|
||||
|
||||
Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them.
|
||||
""",
|
||||
)
|
||||
|
||||
@ -110,6 +212,332 @@ def get_parser():
|
||||
return parser
|
||||
|
||||
|
||||
def export_encoder_model_jit_script(
|
||||
encoder_model: nn.Module,
|
||||
encoder_filename: str,
|
||||
) -> None:
|
||||
"""Export the given encoder model with torch.jit.script()
|
||||
|
||||
Args:
|
||||
encoder_model:
|
||||
The input encoder model
|
||||
encoder_filename:
|
||||
The filename to save the exported model.
|
||||
"""
|
||||
script_model = torch.jit.script(encoder_model)
|
||||
script_model.save(encoder_filename)
|
||||
logging.info(f"Saved to {encoder_filename}")
|
||||
|
||||
|
||||
def export_decoder_model_jit_script(
|
||||
decoder_model: nn.Module,
|
||||
decoder_filename: str,
|
||||
) -> None:
|
||||
"""Export the given decoder model with torch.jit.script()
|
||||
|
||||
Args:
|
||||
decoder_model:
|
||||
The input decoder model
|
||||
decoder_filename:
|
||||
The filename to save the exported model.
|
||||
"""
|
||||
script_model = torch.jit.script(decoder_model)
|
||||
script_model.save(decoder_filename)
|
||||
logging.info(f"Saved to {decoder_filename}")
|
||||
|
||||
|
||||
def export_joiner_model_jit_script(
|
||||
joiner_model: nn.Module,
|
||||
joiner_filename: str,
|
||||
) -> None:
|
||||
"""Export the given joiner model with torch.jit.trace()
|
||||
|
||||
Args:
|
||||
joiner_model:
|
||||
The input joiner model
|
||||
joiner_filename:
|
||||
The filename to save the exported model.
|
||||
"""
|
||||
script_model = torch.jit.script(joiner_model)
|
||||
script_model.save(joiner_filename)
|
||||
logging.info(f"Saved to {joiner_filename}")
|
||||
|
||||
|
||||
def export_encoder_model_jit_trace(
|
||||
encoder_model: nn.Module,
|
||||
encoder_filename: str,
|
||||
) -> None:
|
||||
"""Export the given encoder model with torch.jit.trace()
|
||||
|
||||
Note: The warmup argument is fixed to 1.
|
||||
|
||||
Args:
|
||||
encoder_model:
|
||||
The input encoder model
|
||||
encoder_filename:
|
||||
The filename to save the exported model.
|
||||
"""
|
||||
x = torch.zeros(1, 100, 80, dtype=torch.float32)
|
||||
x_lens = torch.tensor([100], dtype=torch.int64)
|
||||
|
||||
traced_model = torch.jit.trace(encoder_model, (x, x_lens))
|
||||
traced_model.save(encoder_filename)
|
||||
logging.info(f"Saved to {encoder_filename}")
|
||||
|
||||
|
||||
def export_decoder_model_jit_trace(
|
||||
decoder_model: nn.Module,
|
||||
decoder_filename: str,
|
||||
) -> None:
|
||||
"""Export the given decoder model with torch.jit.trace()
|
||||
|
||||
Note: The argument need_pad is fixed to False.
|
||||
|
||||
Args:
|
||||
decoder_model:
|
||||
The input decoder model
|
||||
decoder_filename:
|
||||
The filename to save the exported model.
|
||||
"""
|
||||
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
|
||||
need_pad = torch.tensor([False])
|
||||
|
||||
traced_model = torch.jit.trace(decoder_model, (y, need_pad))
|
||||
traced_model.save(decoder_filename)
|
||||
logging.info(f"Saved to {decoder_filename}")
|
||||
|
||||
|
||||
def export_joiner_model_jit_trace(
|
||||
joiner_model: nn.Module,
|
||||
joiner_filename: str,
|
||||
) -> None:
|
||||
"""Export the given joiner model with torch.jit.trace()
|
||||
|
||||
Note: The argument project_input is fixed to True. A user should not
|
||||
project the encoder_out/decoder_out by himself/herself. The exported joiner
|
||||
will do that for the user.
|
||||
|
||||
Args:
|
||||
joiner_model:
|
||||
The input joiner model
|
||||
joiner_filename:
|
||||
The filename to save the exported model.
|
||||
|
||||
"""
|
||||
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
|
||||
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
|
||||
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
|
||||
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
|
||||
|
||||
traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out))
|
||||
traced_model.save(joiner_filename)
|
||||
logging.info(f"Saved to {joiner_filename}")
|
||||
|
||||
|
||||
def export_encoder_model_onnx(
|
||||
encoder_model: nn.Module,
|
||||
encoder_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the given encoder model to ONNX format.
|
||||
The exported model has two inputs:
|
||||
|
||||
- x, a tensor of shape (N, T, C); dtype is torch.float32
|
||||
- x_lens, a tensor of shape (N,); dtype is torch.int64
|
||||
|
||||
and it has two outputs:
|
||||
|
||||
- encoder_out, a tensor of shape (N, T, C)
|
||||
- encoder_out_lens, a tensor of shape (N,)
|
||||
|
||||
Note: The warmup argument is fixed to 1.
|
||||
|
||||
Args:
|
||||
encoder_model:
|
||||
The input encoder model
|
||||
encoder_filename:
|
||||
The filename to save the exported ONNX model.
|
||||
opset_version:
|
||||
The opset version to use.
|
||||
"""
|
||||
x = torch.zeros(1, 100, 80, dtype=torch.float32)
|
||||
x_lens = torch.tensor([100], dtype=torch.int64)
|
||||
|
||||
# encoder_model = torch.jit.script(encoder_model)
|
||||
# It throws the following error for the above statement
|
||||
#
|
||||
# RuntimeError: Exporting the operator __is_ to ONNX opset version
|
||||
# 11 is not supported. Please feel free to request support or
|
||||
# submit a pull request on PyTorch GitHub.
|
||||
#
|
||||
# I cannot find which statement causes the above error.
|
||||
# torch.onnx.export() will use torch.jit.trace() internally, which
|
||||
# works well for the current reworked model
|
||||
warmup = 1.0
|
||||
torch.onnx.export(
|
||||
encoder_model,
|
||||
(x, x_lens, warmup),
|
||||
encoder_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["x", "x_lens", "warmup"],
|
||||
output_names=["encoder_out", "encoder_out_lens"],
|
||||
dynamic_axes={
|
||||
"x": {0: "N", 1: "T"},
|
||||
"x_lens": {0: "N"},
|
||||
"encoder_out": {0: "N", 1: "T"},
|
||||
"encoder_out_lens": {0: "N"},
|
||||
},
|
||||
)
|
||||
logging.info(f"Saved to {encoder_filename}")
|
||||
|
||||
|
||||
def export_decoder_model_onnx(
|
||||
decoder_model: nn.Module,
|
||||
decoder_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the decoder model to ONNX format.
|
||||
|
||||
The exported model has one input:
|
||||
|
||||
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
|
||||
|
||||
and has one output:
|
||||
|
||||
- decoder_out: a torch.float32 tensor of shape (N, 1, C)
|
||||
|
||||
Note: The argument need_pad is fixed to False.
|
||||
|
||||
Args:
|
||||
decoder_model:
|
||||
The decoder model to be exported.
|
||||
decoder_filename:
|
||||
Filename to save the exported ONNX model.
|
||||
opset_version:
|
||||
The opset version to use.
|
||||
"""
|
||||
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
|
||||
need_pad = False # Always False, so we can use torch.jit.trace() here
|
||||
# Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script()
|
||||
# in this case
|
||||
torch.onnx.export(
|
||||
decoder_model,
|
||||
(y, need_pad),
|
||||
decoder_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["y", "need_pad"],
|
||||
output_names=["decoder_out"],
|
||||
dynamic_axes={
|
||||
"y": {0: "N"},
|
||||
"decoder_out": {0: "N"},
|
||||
},
|
||||
)
|
||||
logging.info(f"Saved to {decoder_filename}")
|
||||
|
||||
|
||||
def export_joiner_model_onnx(
|
||||
joiner_model: nn.Module,
|
||||
joiner_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the joiner model to ONNX format.
|
||||
The exported joiner model has two inputs:
|
||||
|
||||
- projected_encoder_out: a tensor of shape (N, joiner_dim)
|
||||
- projected_decoder_out: a tensor of shape (N, joiner_dim)
|
||||
|
||||
and produces one output:
|
||||
|
||||
- logit: a tensor of shape (N, vocab_size)
|
||||
|
||||
The exported encoder_proj model has one input:
|
||||
|
||||
- encoder_out: a tensor of shape (N, encoder_out_dim)
|
||||
|
||||
and produces one output:
|
||||
|
||||
- projected_encoder_out: a tensor of shape (N, joiner_dim)
|
||||
|
||||
The exported decoder_proj model has one input:
|
||||
|
||||
- decoder_out: a tensor of shape (N, decoder_out_dim)
|
||||
|
||||
and produces one output:
|
||||
|
||||
- projected_decoder_out: a tensor of shape (N, joiner_dim)
|
||||
"""
|
||||
encoder_proj_filename = str(joiner_filename).replace(
|
||||
".onnx", "_encoder_proj.onnx"
|
||||
)
|
||||
|
||||
decoder_proj_filename = str(joiner_filename).replace(
|
||||
".onnx", "_decoder_proj.onnx"
|
||||
)
|
||||
|
||||
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
|
||||
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
|
||||
joiner_dim = joiner_model.decoder_proj.weight.shape[0]
|
||||
|
||||
projected_encoder_out = torch.rand(1, joiner_dim, dtype=torch.float32)
|
||||
projected_decoder_out = torch.rand(1, joiner_dim, dtype=torch.float32)
|
||||
|
||||
project_input = False
|
||||
# Note: It uses torch.jit.trace() internally
|
||||
torch.onnx.export(
|
||||
joiner_model,
|
||||
(projected_encoder_out, projected_decoder_out, project_input),
|
||||
joiner_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=[
|
||||
"projected_encoder_out",
|
||||
"projected_decoder_out",
|
||||
"project_input",
|
||||
],
|
||||
output_names=["logit"],
|
||||
dynamic_axes={
|
||||
"projected_encoder_out": {0: "N"},
|
||||
"projected_decoder_out": {0: "N"},
|
||||
"logit": {0: "N"},
|
||||
},
|
||||
)
|
||||
logging.info(f"Saved to {joiner_filename}")
|
||||
|
||||
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
|
||||
torch.onnx.export(
|
||||
joiner_model.encoder_proj,
|
||||
encoder_out,
|
||||
encoder_proj_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["encoder_out"],
|
||||
output_names=["projected_encoder_out"],
|
||||
dynamic_axes={
|
||||
"encoder_out": {0: "N"},
|
||||
"projected_encoder_out": {0: "N"},
|
||||
},
|
||||
)
|
||||
logging.info(f"Saved to {encoder_proj_filename}")
|
||||
|
||||
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
|
||||
torch.onnx.export(
|
||||
joiner_model.decoder_proj,
|
||||
decoder_out,
|
||||
decoder_proj_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["decoder_out"],
|
||||
output_names=["projected_decoder_out"],
|
||||
dynamic_axes={
|
||||
"decoder_out": {0: "N"},
|
||||
"projected_decoder_out": {0: "N"},
|
||||
},
|
||||
)
|
||||
logging.info(f"Saved to {decoder_proj_filename}")
|
||||
|
||||
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
@ -147,22 +575,66 @@ def main():
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
|
||||
model.eval()
|
||||
|
||||
model.to("cpu")
|
||||
model.eval()
|
||||
|
||||
if params.jit:
|
||||
if params.onnx is True:
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
opset_version = 11
|
||||
logging.info("Exporting to onnx format")
|
||||
encoder_filename = params.exp_dir / "encoder.onnx"
|
||||
export_encoder_model_onnx(
|
||||
model.encoder,
|
||||
encoder_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
|
||||
decoder_filename = params.exp_dir / "decoder.onnx"
|
||||
export_decoder_model_onnx(
|
||||
model.decoder,
|
||||
decoder_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
|
||||
joiner_filename = params.exp_dir / "joiner.onnx"
|
||||
export_joiner_model_onnx(
|
||||
model.joiner,
|
||||
joiner_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
elif params.jit:
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
logging.info("Using torch.jit.script")
|
||||
# We won't use the forward() method of the model in C++, so just ignore
|
||||
# it here.
|
||||
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||
# torch scriptabe.
|
||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||
logging.info("Using torch.jit.script")
|
||||
model = torch.jit.script(model)
|
||||
filename = params.exp_dir / "cpu_jit.pt"
|
||||
model.save(str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
|
||||
# Also export encoder/decoder/joiner separately
|
||||
encoder_filename = params.exp_dir / "encoder_jit_script.pt"
|
||||
export_encoder_model_jit_script(model.encoder, encoder_filename)
|
||||
|
||||
decoder_filename = params.exp_dir / "decoder_jit_script.pt"
|
||||
export_decoder_model_jit_script(model.decoder, decoder_filename)
|
||||
|
||||
joiner_filename = params.exp_dir / "joiner_jit_script.pt"
|
||||
export_joiner_model_jit_script(model.joiner, joiner_filename)
|
||||
elif params.jit_trace is True:
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
logging.info("Using torch.jit.trace()")
|
||||
encoder_filename = params.exp_dir / "encoder_jit_trace.pt"
|
||||
export_encoder_model_jit_trace(model.encoder, encoder_filename)
|
||||
|
||||
decoder_filename = params.exp_dir / "decoder_jit_trace.pt"
|
||||
export_decoder_model_jit_trace(model.decoder, decoder_filename)
|
||||
|
||||
joiner_filename = params.exp_dir / "joiner_jit_trace.pt"
|
||||
export_joiner_model_jit_trace(model.joiner, joiner_filename)
|
||||
else:
|
||||
logging.info("Not using torch.jit.script")
|
||||
# Save it using a format so that it can be loaded
|
||||
|
||||
339
egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py
Executable file
339
egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py
Executable file
@ -0,0 +1,339 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads torchscript models, either exported by `torch.jit.trace()`
|
||||
or by `torch.jit.script()`, and uses them to decode waves.
|
||||
You can use the following command to get the exported models:
|
||||
|
||||
./pruned_transducer_stateless2/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--tokens data/lang_char/tokens.txt \
|
||||
--epoch 10 \
|
||||
--avg 2 \
|
||||
--jit-trace 1
|
||||
|
||||
or
|
||||
|
||||
./pruned_transducer_stateless2/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--tokens data/lang_char/tokens.txt \
|
||||
--epoch 10 \
|
||||
--avg 2 \
|
||||
--jit 1
|
||||
|
||||
Usage of this script:
|
||||
|
||||
./pruned_transducer_stateless2/jit_pretrained.py \
|
||||
--encoder-model-filename ./pruned_transducer_stateless2/exp/encoder_jit_trace.pt \
|
||||
--decoder-model-filename ./pruned_transducer_stateless2/exp/decoder_jit_trace.pt \
|
||||
--joiner-model-filename ./pruned_transducer_stateless2/exp/joiner_jit_trace.pt \
|
||||
--tokens data/lang_char/tokens.txt \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
or
|
||||
|
||||
./pruned_transducer_stateless2/jit_pretrained.py \
|
||||
--encoder-model-filename ./pruned_transducer_stateless2/exp/encoder_jit_script.pt \
|
||||
--decoder-model-filename ./pruned_transducer_stateless2/exp/decoder_jit_script.pt \
|
||||
--joiner-model-filename ./pruned_transducer_stateless2/exp/joiner_jit_script.pt \
|
||||
--tokens data/lang_char/tokens.txt \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
You can find pretrained models at
|
||||
https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2/tree/main/exp
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import k2
|
||||
import kaldifeat
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the encoder torchscript model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoder-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the decoder torchscript model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the joiner torchscript model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens",
|
||||
type=str,
|
||||
help="""Path to tokens.txt""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="The sample rate of the input sound file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Context size of the decoder model",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert sample_rate == expected_sample_rate, (
|
||||
f"expected sample rate: {expected_sample_rate}. "
|
||||
f"Given: {sample_rate}"
|
||||
)
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
def greedy_search(
|
||||
decoder: torch.jit.ScriptModule,
|
||||
joiner: torch.jit.ScriptModule,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
context_size: int,
|
||||
) -> List[List[int]]:
|
||||
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
||||
Args:
|
||||
decoder:
|
||||
The decoder model.
|
||||
joiner:
|
||||
The joiner model.
|
||||
encoder_out:
|
||||
A 3-D tensor of shape (N, T, C)
|
||||
encoder_out_lens:
|
||||
A 1-D tensor of shape (N,).
|
||||
context_size:
|
||||
The context size of the decoder model.
|
||||
Returns:
|
||||
Return the decoded results for each utterance.
|
||||
"""
|
||||
assert encoder_out.ndim == 3
|
||||
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||
|
||||
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
|
||||
input=encoder_out,
|
||||
lengths=encoder_out_lens.cpu(),
|
||||
batch_first=True,
|
||||
enforce_sorted=False,
|
||||
)
|
||||
|
||||
device = encoder_out.device
|
||||
blank_id = 0 # hard-code to 0
|
||||
|
||||
batch_size_list = packed_encoder_out.batch_sizes.tolist()
|
||||
N = encoder_out.size(0)
|
||||
|
||||
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||
assert N == batch_size_list[0], (N, batch_size_list)
|
||||
|
||||
hyps = [[blank_id] * context_size for _ in range(N)]
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
hyps,
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
) # (N, context_size)
|
||||
|
||||
decoder_out = decoder(
|
||||
decoder_input,
|
||||
need_pad=torch.tensor([False]),
|
||||
).squeeze(1)
|
||||
|
||||
offset = 0
|
||||
for batch_size in batch_size_list:
|
||||
start = offset
|
||||
end = offset + batch_size
|
||||
current_encoder_out = packed_encoder_out.data[start:end]
|
||||
current_encoder_out = current_encoder_out
|
||||
# current_encoder_out's shape: (batch_size, encoder_out_dim)
|
||||
offset = end
|
||||
|
||||
decoder_out = decoder_out[:batch_size]
|
||||
|
||||
logits = joiner(
|
||||
current_encoder_out,
|
||||
decoder_out,
|
||||
)
|
||||
# logits'shape (batch_size, vocab_size)
|
||||
|
||||
assert logits.ndim == 2, logits.shape
|
||||
y = logits.argmax(dim=1).tolist()
|
||||
emitted = False
|
||||
for i, v in enumerate(y):
|
||||
if v != blank_id:
|
||||
hyps[i].append(v)
|
||||
emitted = True
|
||||
if emitted:
|
||||
# update decoder output
|
||||
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
|
||||
decoder_input = torch.tensor(
|
||||
decoder_input,
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
decoder_out = decoder(
|
||||
decoder_input,
|
||||
need_pad=torch.tensor([False]),
|
||||
)
|
||||
decoder_out = decoder_out.squeeze(1)
|
||||
|
||||
sorted_ans = [h[context_size:] for h in hyps]
|
||||
ans = []
|
||||
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||
for i in range(N):
|
||||
ans.append(sorted_ans[unsorted_indices[i]])
|
||||
|
||||
return ans
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
encoder = torch.jit.load(args.encoder_model_filename)
|
||||
decoder = torch.jit.load(args.decoder_model_filename)
|
||||
joiner = torch.jit.load(args.joiner_model_filename)
|
||||
|
||||
encoder.eval()
|
||||
decoder.eval()
|
||||
joiner.eval()
|
||||
|
||||
encoder.to(device)
|
||||
decoder.to(device)
|
||||
joiner.to(device)
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = device
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = args.sample_rate
|
||||
opts.mel_opts.num_bins = 80
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
logging.info(f"Reading sound files: {args.sound_files}")
|
||||
waves = read_sound_files(
|
||||
filenames=args.sound_files,
|
||||
expected_sample_rate=args.sample_rate,
|
||||
)
|
||||
waves = [w.to(device) for w in waves]
|
||||
|
||||
logging.info("Decoding started")
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
|
||||
features = pad_sequence(
|
||||
features,
|
||||
batch_first=True,
|
||||
padding_value=math.log(1e-10),
|
||||
)
|
||||
|
||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||
|
||||
encoder_out, encoder_out_lens = encoder(
|
||||
x=features,
|
||||
x_lens=feature_lengths,
|
||||
)
|
||||
|
||||
hyps = greedy_search(
|
||||
decoder=decoder,
|
||||
joiner=joiner,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
context_size=args.context_size,
|
||||
)
|
||||
symbol_table = k2.SymbolTable.from_file(args.tokens)
|
||||
s = "\n"
|
||||
for filename, hyp in zip(args.sound_files, hyps):
|
||||
words = "".join([symbol_table[i] for i in hyp])
|
||||
s += f"{filename}:\n{words}\n\n"
|
||||
logging.info(s)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
307
egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py
Executable file
307
egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py
Executable file
@ -0,0 +1,307 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2022 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script checks that exported onnx models produce the same output
|
||||
with the given torchscript model for the same input.
|
||||
|
||||
Usage:
|
||||
|
||||
./pruned_transducer_stateless2/onnx_check.py \
|
||||
--jit-filename ./t/cpu_jit.pt \
|
||||
--onnx-encoder-filename ./t/encoder.onnx \
|
||||
--onnx-decoder-filename ./t/decoder.onnx \
|
||||
--onnx-joiner-filename ./t/joiner.onnx \
|
||||
--onnx-joiner-encoder-proj-filename ./t/joiner_encoder_proj.onnx \
|
||||
--onnx-joiner-decoder-proj-filename ./t/joiner_decoder_proj.onnx
|
||||
|
||||
You can generate cpu_jit.pt, encoder.onnx, decoder.onnx, and other
|
||||
xxx.onnx files using ./export.py
|
||||
|
||||
We provide pretrained models at:
|
||||
https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2/tree/main/exp
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
|
||||
ort.set_default_logger_severity(3)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--jit-filename",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the torchscript model exported by torch.jit.script",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--onnx-encoder-filename",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the onnx encoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--onnx-decoder-filename",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the onnx decoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--onnx-joiner-filename",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the onnx joiner model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--onnx-joiner-encoder-proj-filename",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the onnx joiner encoder projection model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--onnx-joiner-decoder-proj-filename",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the onnx joiner decoder projection model",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def test_encoder(
|
||||
model: torch.jit.ScriptModule,
|
||||
encoder_session: ort.InferenceSession,
|
||||
):
|
||||
inputs = encoder_session.get_inputs()
|
||||
outputs = encoder_session.get_outputs()
|
||||
input_names = [n.name for n in inputs]
|
||||
output_names = [n.name for n in outputs]
|
||||
|
||||
assert inputs[0].shape == ["N", "T", 80]
|
||||
assert inputs[1].shape == ["N"]
|
||||
|
||||
for N in [1, 5]:
|
||||
for T in [12, 25]:
|
||||
print("N, T", N, T)
|
||||
x = torch.rand(N, T, 80, dtype=torch.float32)
|
||||
x_lens = torch.randint(low=10, high=T + 1, size=(N,))
|
||||
x_lens[0] = T
|
||||
|
||||
encoder_inputs = {
|
||||
input_names[0]: x.numpy(),
|
||||
input_names[1]: x_lens.numpy(),
|
||||
}
|
||||
encoder_out, encoder_out_lens = encoder_session.run(
|
||||
output_names,
|
||||
encoder_inputs,
|
||||
)
|
||||
|
||||
torch_encoder_out, torch_encoder_out_lens = model.encoder(x, x_lens)
|
||||
|
||||
encoder_out = torch.from_numpy(encoder_out)
|
||||
assert torch.allclose(encoder_out, torch_encoder_out, atol=1e-05), (
|
||||
(encoder_out - torch_encoder_out).abs().max(),
|
||||
encoder_out.shape,
|
||||
torch_encoder_out.shape,
|
||||
)
|
||||
|
||||
|
||||
def test_decoder(
|
||||
model: torch.jit.ScriptModule,
|
||||
decoder_session: ort.InferenceSession,
|
||||
):
|
||||
inputs = decoder_session.get_inputs()
|
||||
outputs = decoder_session.get_outputs()
|
||||
input_names = [n.name for n in inputs]
|
||||
output_names = [n.name for n in outputs]
|
||||
|
||||
assert inputs[0].shape == ["N", 2]
|
||||
for N in [1, 5, 10]:
|
||||
y = torch.randint(low=1, high=500, size=(10, 2))
|
||||
|
||||
decoder_inputs = {input_names[0]: y.numpy()}
|
||||
decoder_out = decoder_session.run(
|
||||
output_names,
|
||||
decoder_inputs,
|
||||
)[0]
|
||||
decoder_out = torch.from_numpy(decoder_out)
|
||||
|
||||
torch_decoder_out = model.decoder(y, need_pad=False)
|
||||
assert torch.allclose(decoder_out, torch_decoder_out, atol=1e-5), (
|
||||
(decoder_out - torch_decoder_out).abs().max()
|
||||
)
|
||||
|
||||
|
||||
def test_joiner(
|
||||
model: torch.jit.ScriptModule,
|
||||
joiner_session: ort.InferenceSession,
|
||||
joiner_encoder_proj_session: ort.InferenceSession,
|
||||
joiner_decoder_proj_session: ort.InferenceSession,
|
||||
):
|
||||
joiner_inputs = joiner_session.get_inputs()
|
||||
joiner_outputs = joiner_session.get_outputs()
|
||||
joiner_input_names = [n.name for n in joiner_inputs]
|
||||
joiner_output_names = [n.name for n in joiner_outputs]
|
||||
|
||||
assert joiner_inputs[0].shape == ["N", 512]
|
||||
assert joiner_inputs[1].shape == ["N", 512]
|
||||
|
||||
joiner_encoder_proj_inputs = joiner_encoder_proj_session.get_inputs()
|
||||
encoder_proj_input_name = joiner_encoder_proj_inputs[0].name
|
||||
|
||||
assert joiner_encoder_proj_inputs[0].shape == ["N", 512]
|
||||
|
||||
joiner_encoder_proj_outputs = joiner_encoder_proj_session.get_outputs()
|
||||
encoder_proj_output_name = joiner_encoder_proj_outputs[0].name
|
||||
|
||||
joiner_decoder_proj_inputs = joiner_decoder_proj_session.get_inputs()
|
||||
decoder_proj_input_name = joiner_decoder_proj_inputs[0].name
|
||||
|
||||
assert joiner_decoder_proj_inputs[0].shape == ["N", 512]
|
||||
|
||||
joiner_decoder_proj_outputs = joiner_decoder_proj_session.get_outputs()
|
||||
decoder_proj_output_name = joiner_decoder_proj_outputs[0].name
|
||||
|
||||
for N in [1, 5, 10]:
|
||||
encoder_out = torch.rand(N, 512)
|
||||
decoder_out = torch.rand(N, 512)
|
||||
|
||||
projected_encoder_out = torch.rand(N, 512)
|
||||
projected_decoder_out = torch.rand(N, 512)
|
||||
|
||||
joiner_inputs = {
|
||||
joiner_input_names[0]: projected_encoder_out.numpy(),
|
||||
joiner_input_names[1]: projected_decoder_out.numpy(),
|
||||
}
|
||||
joiner_out = joiner_session.run(joiner_output_names, joiner_inputs)[0]
|
||||
joiner_out = torch.from_numpy(joiner_out)
|
||||
|
||||
torch_joiner_out = model.joiner(
|
||||
projected_encoder_out,
|
||||
projected_decoder_out,
|
||||
project_input=False,
|
||||
)
|
||||
assert torch.allclose(joiner_out, torch_joiner_out, atol=1e-5), (
|
||||
(joiner_out - torch_joiner_out).abs().max()
|
||||
)
|
||||
|
||||
# Now test encoder_proj
|
||||
joiner_encoder_proj_inputs = {
|
||||
encoder_proj_input_name: encoder_out.numpy()
|
||||
}
|
||||
joiner_encoder_proj_out = joiner_encoder_proj_session.run(
|
||||
[encoder_proj_output_name], joiner_encoder_proj_inputs
|
||||
)[0]
|
||||
joiner_encoder_proj_out = torch.from_numpy(joiner_encoder_proj_out)
|
||||
|
||||
torch_joiner_encoder_proj_out = model.joiner.encoder_proj(encoder_out)
|
||||
assert torch.allclose(
|
||||
joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5
|
||||
), (
|
||||
(joiner_encoder_proj_out - torch_joiner_encoder_proj_out)
|
||||
.abs()
|
||||
.max()
|
||||
)
|
||||
|
||||
# Now test decoder_proj
|
||||
joiner_decoder_proj_inputs = {
|
||||
decoder_proj_input_name: decoder_out.numpy()
|
||||
}
|
||||
joiner_decoder_proj_out = joiner_decoder_proj_session.run(
|
||||
[decoder_proj_output_name], joiner_decoder_proj_inputs
|
||||
)[0]
|
||||
joiner_decoder_proj_out = torch.from_numpy(joiner_decoder_proj_out)
|
||||
|
||||
torch_joiner_decoder_proj_out = model.joiner.decoder_proj(decoder_out)
|
||||
assert torch.allclose(
|
||||
joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5
|
||||
), (
|
||||
(joiner_decoder_proj_out - torch_joiner_decoder_proj_out)
|
||||
.abs()
|
||||
.max()
|
||||
)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
model = torch.jit.load(args.jit_filename)
|
||||
|
||||
options = ort.SessionOptions()
|
||||
options.inter_op_num_threads = 1
|
||||
options.intra_op_num_threads = 1
|
||||
|
||||
logging.info("Test encoder")
|
||||
encoder_session = ort.InferenceSession(
|
||||
args.onnx_encoder_filename,
|
||||
sess_options=options,
|
||||
)
|
||||
test_encoder(model, encoder_session)
|
||||
|
||||
logging.info("Test decoder")
|
||||
decoder_session = ort.InferenceSession(
|
||||
args.onnx_decoder_filename,
|
||||
sess_options=options,
|
||||
)
|
||||
test_decoder(model, decoder_session)
|
||||
|
||||
logging.info("Test joiner")
|
||||
joiner_session = ort.InferenceSession(
|
||||
args.onnx_joiner_filename,
|
||||
sess_options=options,
|
||||
)
|
||||
joiner_encoder_proj_session = ort.InferenceSession(
|
||||
args.onnx_joiner_encoder_proj_filename,
|
||||
sess_options=options,
|
||||
)
|
||||
joiner_decoder_proj_session = ort.InferenceSession(
|
||||
args.onnx_joiner_decoder_proj_filename,
|
||||
sess_options=options,
|
||||
)
|
||||
test_joiner(
|
||||
model,
|
||||
joiner_session,
|
||||
joiner_encoder_proj_session,
|
||||
joiner_decoder_proj_session,
|
||||
)
|
||||
logging.info("Finished checking ONNX models")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.manual_seed(20220727)
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
391
egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py
Executable file
391
egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py
Executable file
@ -0,0 +1,391 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads ONNX models and uses them to decode waves.
|
||||
You can use the following command to get the exported models:
|
||||
|
||||
./pruned_transducer_stateless2/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||
--lang-dir data/lang_char \
|
||||
--epoch 20 \
|
||||
--avg 10 \
|
||||
--onnx 1
|
||||
|
||||
Usage of this script:
|
||||
|
||||
./pruned_transducer_stateless3/onnx_pretrained.py \
|
||||
--encoder-model-filename ./pruned_transducer_stateless3/exp/encoder.onnx \
|
||||
--decoder-model-filename ./pruned_transducer_stateless3/exp/decoder.onnx \
|
||||
--joiner-model-filename ./pruned_transducer_stateless3/exp/joiner.onnx \
|
||||
--joiner-encoder-proj-model-filename ./pruned_transducer_stateless3/exp/joiner_encoder_proj.onnx \
|
||||
--joiner-decoder-proj-model-filename ./pruned_transducer_stateless3/exp/joiner_decoder_proj.onnx \
|
||||
--tokens data/lang_char/tokens.txt \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
We provide pretrained models at:
|
||||
https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2/tree/main/exp
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import k2
|
||||
import kaldifeat
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the encoder onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoder-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the decoder onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the joiner onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner-encoder-proj-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the joiner encoder_proj onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner-decoder-proj-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the joiner decoder_proj onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens",
|
||||
type=str,
|
||||
help="""Path to tokens.txt""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="The sample rate of the input sound file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Context size of the decoder model",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert sample_rate == expected_sample_rate, (
|
||||
f"expected sample rate: {expected_sample_rate}. "
|
||||
f"Given: {sample_rate}"
|
||||
)
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
def greedy_search(
|
||||
decoder: ort.InferenceSession,
|
||||
joiner: ort.InferenceSession,
|
||||
joiner_encoder_proj: ort.InferenceSession,
|
||||
joiner_decoder_proj: ort.InferenceSession,
|
||||
encoder_out: np.ndarray,
|
||||
encoder_out_lens: np.ndarray,
|
||||
context_size: int,
|
||||
) -> List[List[int]]:
|
||||
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
||||
Args:
|
||||
decoder:
|
||||
The decoder model.
|
||||
joiner:
|
||||
The joiner model.
|
||||
joiner_encoder_proj:
|
||||
The joiner encoder projection model.
|
||||
joiner_decoder_proj:
|
||||
The joiner decoder projection model.
|
||||
encoder_out:
|
||||
A 3-D tensor of shape (N, T, C)
|
||||
encoder_out_lens:
|
||||
A 1-D tensor of shape (N,).
|
||||
context_size:
|
||||
The context size of the decoder model.
|
||||
Returns:
|
||||
Return the decoded results for each utterance.
|
||||
"""
|
||||
encoder_out = torch.from_numpy(encoder_out)
|
||||
encoder_out_lens = torch.from_numpy(encoder_out_lens)
|
||||
assert encoder_out.ndim == 3
|
||||
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||
|
||||
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
|
||||
input=encoder_out,
|
||||
lengths=encoder_out_lens.cpu(),
|
||||
batch_first=True,
|
||||
enforce_sorted=False,
|
||||
)
|
||||
|
||||
projected_encoder_out = joiner_encoder_proj.run(
|
||||
[joiner_encoder_proj.get_outputs()[0].name],
|
||||
{
|
||||
joiner_encoder_proj.get_inputs()[
|
||||
0
|
||||
].name: packed_encoder_out.data.numpy()
|
||||
},
|
||||
)[0]
|
||||
|
||||
blank_id = 0 # hard-code to 0
|
||||
|
||||
batch_size_list = packed_encoder_out.batch_sizes.tolist()
|
||||
N = encoder_out.size(0)
|
||||
|
||||
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||
assert N == batch_size_list[0], (N, batch_size_list)
|
||||
|
||||
hyps = [[blank_id] * context_size for _ in range(N)]
|
||||
|
||||
decoder_input_nodes = decoder.get_inputs()
|
||||
decoder_output_nodes = decoder.get_outputs()
|
||||
|
||||
joiner_input_nodes = joiner.get_inputs()
|
||||
joiner_output_nodes = joiner.get_outputs()
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
hyps,
|
||||
dtype=torch.int64,
|
||||
) # (N, context_size)
|
||||
|
||||
decoder_out = decoder.run(
|
||||
[decoder_output_nodes[0].name],
|
||||
{
|
||||
decoder_input_nodes[0].name: decoder_input.numpy(),
|
||||
},
|
||||
)[0].squeeze(1)
|
||||
projected_decoder_out = joiner_decoder_proj.run(
|
||||
[joiner_decoder_proj.get_outputs()[0].name],
|
||||
{joiner_decoder_proj.get_inputs()[0].name: decoder_out},
|
||||
)[0]
|
||||
|
||||
projected_decoder_out = torch.from_numpy(projected_decoder_out)
|
||||
|
||||
offset = 0
|
||||
for batch_size in batch_size_list:
|
||||
start = offset
|
||||
end = offset + batch_size
|
||||
current_encoder_out = projected_encoder_out[start:end]
|
||||
# current_encoder_out's shape: (batch_size, encoder_out_dim)
|
||||
offset = end
|
||||
|
||||
projected_decoder_out = projected_decoder_out[:batch_size]
|
||||
|
||||
logits = joiner.run(
|
||||
[joiner_output_nodes[0].name],
|
||||
{
|
||||
joiner_input_nodes[0].name: current_encoder_out,
|
||||
joiner_input_nodes[1].name: projected_decoder_out.numpy(),
|
||||
},
|
||||
)[0]
|
||||
logits = torch.from_numpy(logits).squeeze(1).squeeze(1)
|
||||
# logits'shape (batch_size, vocab_size)
|
||||
|
||||
assert logits.ndim == 2, logits.shape
|
||||
y = logits.argmax(dim=1).tolist()
|
||||
emitted = False
|
||||
for i, v in enumerate(y):
|
||||
if v != blank_id:
|
||||
hyps[i].append(v)
|
||||
emitted = True
|
||||
if emitted:
|
||||
# update decoder output
|
||||
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
|
||||
decoder_input = torch.tensor(
|
||||
decoder_input,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
decoder_out = decoder.run(
|
||||
[decoder_output_nodes[0].name],
|
||||
{
|
||||
decoder_input_nodes[0].name: decoder_input.numpy(),
|
||||
},
|
||||
)[0].squeeze(1)
|
||||
projected_decoder_out = joiner_decoder_proj.run(
|
||||
[joiner_decoder_proj.get_outputs()[0].name],
|
||||
{joiner_decoder_proj.get_inputs()[0].name: decoder_out},
|
||||
)[0]
|
||||
projected_decoder_out = torch.from_numpy(projected_decoder_out)
|
||||
|
||||
sorted_ans = [h[context_size:] for h in hyps]
|
||||
ans = []
|
||||
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||
for i in range(N):
|
||||
ans.append(sorted_ans[unsorted_indices[i]])
|
||||
|
||||
return ans
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
session_opts = ort.SessionOptions()
|
||||
session_opts.inter_op_num_threads = 1
|
||||
session_opts.intra_op_num_threads = 1
|
||||
|
||||
encoder = ort.InferenceSession(
|
||||
args.encoder_model_filename,
|
||||
sess_options=session_opts,
|
||||
)
|
||||
|
||||
decoder = ort.InferenceSession(
|
||||
args.decoder_model_filename,
|
||||
sess_options=session_opts,
|
||||
)
|
||||
|
||||
joiner = ort.InferenceSession(
|
||||
args.joiner_model_filename,
|
||||
sess_options=session_opts,
|
||||
)
|
||||
|
||||
joiner_encoder_proj = ort.InferenceSession(
|
||||
args.joiner_encoder_proj_model_filename,
|
||||
sess_options=session_opts,
|
||||
)
|
||||
|
||||
joiner_decoder_proj = ort.InferenceSession(
|
||||
args.joiner_decoder_proj_model_filename,
|
||||
sess_options=session_opts,
|
||||
)
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = "cpu"
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = args.sample_rate
|
||||
opts.mel_opts.num_bins = 80
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
logging.info(f"Reading sound files: {args.sound_files}")
|
||||
waves = read_sound_files(
|
||||
filenames=args.sound_files,
|
||||
expected_sample_rate=args.sample_rate,
|
||||
)
|
||||
|
||||
logging.info("Decoding started")
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
|
||||
features = pad_sequence(
|
||||
features,
|
||||
batch_first=True,
|
||||
padding_value=math.log(1e-10),
|
||||
)
|
||||
|
||||
feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64)
|
||||
|
||||
encoder_input_nodes = encoder.get_inputs()
|
||||
encoder_out_nodes = encoder.get_outputs()
|
||||
encoder_out, encoder_out_lens = encoder.run(
|
||||
[encoder_out_nodes[0].name, encoder_out_nodes[1].name],
|
||||
{
|
||||
encoder_input_nodes[0].name: features.numpy(),
|
||||
encoder_input_nodes[1].name: feature_lengths.numpy(),
|
||||
},
|
||||
)
|
||||
|
||||
hyps = greedy_search(
|
||||
decoder=decoder,
|
||||
joiner=joiner,
|
||||
joiner_encoder_proj=joiner_encoder_proj,
|
||||
joiner_decoder_proj=joiner_decoder_proj,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
context_size=args.context_size,
|
||||
)
|
||||
symbol_table = k2.SymbolTable.from_file(args.tokens)
|
||||
s = "\n"
|
||||
for filename, hyp in zip(args.sound_files, hyps):
|
||||
words = "".join([symbol_table[i] for i in hyp])
|
||||
s += f"{filename}:\n{words}\n\n"
|
||||
logging.info(s)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
13
egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py
Normal file → Executable file
13
egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py
Normal file → Executable file
@ -21,7 +21,7 @@ Usage:
|
||||
./pruned_transducer_stateless2/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
|
||||
--lang-dir ./data/lang_char \
|
||||
--method greedy_search \
|
||||
--decoding-method greedy_search \
|
||||
--max-sym-per-frame 1 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
@ -29,7 +29,7 @@ Usage:
|
||||
./pruned_transducer_stateless2/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
|
||||
--lang-dir ./data/lang_char \
|
||||
--method modified_beam_search \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
@ -37,7 +37,7 @@ Usage:
|
||||
./pruned_transducer_stateless2/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
|
||||
--lang-dir ./data/lang_char \
|
||||
--method fast_beam_search \
|
||||
--decoding-method fast_beam_search \
|
||||
--beam 4 \
|
||||
--max-contexts 4 \
|
||||
--max-states 8 \
|
||||
@ -116,7 +116,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=48000,
|
||||
default=16000,
|
||||
help="The sample rate of the input sound file",
|
||||
)
|
||||
|
||||
@ -124,7 +124,8 @@ def get_parser():
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Used only when --method is beam_search and modified_beam_search ",
|
||||
help="""Used only when --decoding-method is beam_search
|
||||
and modified_beam_search """,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -166,7 +167,7 @@ def get_parser():
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame. Used only when
|
||||
--method is greedy_search.
|
||||
--decoding-method is greedy_search.
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py
|
||||
Loading…
x
Reference in New Issue
Block a user