small fixes

commit cff243eb47
parent 7792e69f0d
Fangjun Kuang, 2024-04-09 14:42:41 +08:00
5 changed files with 69 additions and 93 deletions

egs/audioset/AT/zipformer/export-onnx.py

@@ -6,56 +6,28 @@
 """
 This script exports a transducer model from PyTorch to ONNX.
 
-We use the pre-trained model from
-https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
-as an example to show how to use this file.
-
-1. Download the pre-trained model
-
-cd egs/librispeech/ASR
-
-repo_url=https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12#/
-GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
-repo=$(basename $repo_url)
-
-pushd $repo
-git lfs pull --include "exp/pretrained.pt"
-
-cd exp
-ln -s pretrained.pt epoch-99.pt
-popd
-
-2. Export the model to ONNX
-
-./zipformer/export-onnx.py \
-  --use-averaged-model 0 \
-  --epoch 99 \
-  --avg 1 \
-  --exp-dir $repo/exp \
-  --num-encoder-layers "2,2,3,4,3,2" \
-  --downsampling-factor "1,2,4,8,4,2" \
-  --feedforward-dim "512,768,1024,1536,1024,768" \
-  --num-heads "4,4,4,8,4,4" \
-  --encoder-dim "192,256,384,512,384,256" \
-  --query-head-dim 32 \
-  --value-head-dim 12 \
-  --pos-head-dim 4 \
-  --pos-dim 48 \
-  --encoder-unmasked-dim "192,192,256,256,256,192" \
-  --cnn-module-kernel "31,31,15,15,15,31" \
-  --decoder-dim 512 \
-  --joiner-dim 512 \
-  --causal False \
-  --chunk-size "16,32,64,-1" \
-  --left-context-frames "64,128,256,-1"
-
-It will generate the following 3 files inside $repo/exp:
-
-  - encoder-epoch-99-avg-1.onnx
-  - decoder-epoch-99-avg-1.onnx
-  - joiner-epoch-99-avg-1.onnx
-
-See ./onnx_pretrained.py and ./onnx_check.py for how to
+Usage of this script:
+
+repo_url=https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12
+repo=$(basename $repo_url)
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+pushd $repo/exp
+git lfs pull --include pretrained.pt
+ln -s pretrained.pt epoch-99.pt
+popd
+
+python3 zipformer/export-onnx.py \
+  --exp-dir $repo/exp \
+  --epoch 99 \
+  --avg 1 \
+  --use-averaged-model 0
+
+pushd $repo/exp
+mv model-epoch-99-avg-1.onnx model.onnx
+mv model-epoch-99-avg-1.int8.onnx model.int8.onnx
+popd
+
+See ./onnx_pretrained.py for how to
 use the exported ONNX models.
 """
@@ -261,6 +233,16 @@ def export_audio_tagging_model_onnx(
 
     add_meta_data(filename=filename, meta_data=meta_data)
 
 
+def optimize_model(filename):
+    # see https://github.com/microsoft/onnxruntime/issues/1899#issuecomment-534806537
+    from onnx import optimizer
+
+    passes = ["extract_constant_to_initializer", "eliminate_unused_initializer"]
+    onnx_model = onnx.load(filename)
+    optimized_model = optimizer.optimize(onnx_model, passes)
+    onnx.save(optimized_model, filename)
+
+
 @torch.no_grad()
 def main():
     args = get_parser().parse_args()
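A caveat on the new helper: `from onnx import optimizer` only resolves on onnx < 1.9; newer releases ship the same passes in the separate `onnxoptimizer` package. A minimal sketch of the equivalent helper, assuming `pip install onnxoptimizer`:

    import onnx
    import onnxoptimizer  # successor to the removed onnx.optimizer module

    def optimize_model(filename: str) -> None:
        passes = ["extract_constant_to_initializer", "eliminate_unused_initializer"]
        onnx_model = onnx.load(filename)
        optimized_model = onnxoptimizer.optimize(onnx_model, passes)
        onnx.save(optimized_model, filename)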
@@ -389,6 +371,7 @@ def main():
         model_filename,
         opset_version=opset_version,
     )
+    optimize_model(model_filename)
     logging.info(f"Exported audio tagging model to {model_filename}")
 
     # Generate int8 quantization models
@@ -403,6 +386,7 @@ def main():
         op_types_to_quantize=["MatMul"],
         weight_type=QuantType.QInt8,
     )
+    optimize_model(model_filename_int8)
 
 
 if __name__ == "__main__":
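For context, the `model_filename_int8` being optimized here is produced by onnxruntime's dynamic quantization, with the `op_types_to_quantize` and `weight_type` arguments visible in the hunk above. A self-contained sketch, with illustrative file names:

    from onnxruntime.quantization import QuantType, quantize_dynamic

    quantize_dynamic(
        model_input="model.onnx",        # float32 model produced by the export
        model_output="model.int8.onnx",  # MatMul weights stored as int8
        op_types_to_quantize=["MatMul"],
        weight_type=QuantType.QInt8,
    )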

egs/audioset/AT/zipformer/export.py

@@ -42,6 +42,7 @@ load it by `torch.jit.load("jit_script.pt")`.
 Check ./jit_pretrained.py for its usage.
 
 Check https://github.com/k2-fsa/sherpa
+and https://github.com/k2-fsa/sherpa-onnx
 for how to use the exported models outside of icefall.
 
 (2) Export `model.state_dict()`
@ -55,13 +56,13 @@ for how to use the exported models outside of icefall.
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
load it by `icefall.checkpoint.load_checkpoint()`. load it by `icefall.checkpoint.load_checkpoint()`.
To use the generated file with `zipformer/decode.py`, To use the generated file with `zipformer/evaluate.py`,
you can do: you can do:
cd /path/to/exp_dir cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR cd /path/to/egs/audioset/AT
./zipformer/evaluate.py \ ./zipformer/evaluate.py \
--exp-dir ./zipformer/exp \ --exp-dir ./zipformer/exp \
--use-averaged-model False \ --use-averaged-model False \
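Both export flavors described in this docstring are cheap to consume. A minimal sketch, assuming the paths used above; the `load_checkpoint` call is left commented since its exact signature should be checked against `icefall.checkpoint`:

    import torch

    # (1) jit_script.pt is self-contained TorchScript; no icefall code needed.
    model = torch.jit.load("zipformer/exp/jit_script.pt")
    model.eval()

    # (2) pretrained.pt only stores weights; rebuild the model first, then
    #     restore it via the epoch-9999.pt symlink created above.
    # from icefall.checkpoint import load_checkpoint
    # load_checkpoint("zipformer/exp/epoch-9999.pt", model=model)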

egs/audioset/AT/zipformer/jit_pretrained.py

@@ -28,10 +28,20 @@ You can use the following command to get the exported models:
 
 Usage of this script:
 
-./zipformer/jit_pretrained.py \
-  --nn-model-filename ./zipformer/exp/cpu_jit.pt \
-  /path/to/foo.wav \
-  /path/to/bar.wav
+repo_url=https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12
+repo=$(basename $repo_url)
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+pushd $repo/exp
+git lfs pull --include jit_script.pt
+popd
+
+python3 zipformer/jit_pretrained.py \
+  --nn-model-filename $repo/exp/jit_script.pt \
+  --label-dict $repo/data/class_labels_indices.csv \
+  $repo/test_wavs/1.wav \
+  $repo/test_wavs/2.wav \
+  $repo/test_wavs/3.wav \
+  $repo/test_wavs/4.wav
 
 """
 
 import argparse
@@ -168,7 +178,8 @@ def main():
         topk_prob, topk_index = logit.sigmoid().topk(5)
         topk_labels = [label_dict[index.item()] for index in topk_index]
         logging.info(
-            f"{filename}: Top 5 predicted labels are {topk_labels} with probability of {topk_prob.tolist()}"
+            f"{filename}: Top 5 predicted labels are {topk_labels} with "
+            f"probability of {topk_prob.tolist()}"
         )
 
     logging.info("Done")

egs/audioset/AT/zipformer/onnx_pretrained.py

@@ -17,48 +17,25 @@
 # limitations under the License.
 
 """
 This script loads ONNX models and uses them to decode waves.
+You can use the following command to get the exported models:
 
-We use the pre-trained model from
-https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12#/
-as an example to show how to use this file.
+Usage of this script:
 
-1. Download the pre-trained model
-
-cd egs/librispeech/ASR
-
-repo_url=https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12#/
-GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
-repo=$(basename $repo_url)
-
-pushd $repo
-git lfs pull --include "exp/pretrained.pt"
-
-cd exp
-ln -s pretrained.pt epoch-99.pt
-popd
-
-2. Export the model to ONNX
-
-./zipformer/export-onnx.py \
-  --use-averaged-model 0 \
-  --epoch 99 \
-  --avg 1 \
-  --exp-dir $repo/exp \
-  --causal False
-
-It will generate the following 3 files inside $repo/exp:
-
-  - model-epoch-99-avg-1.onnx
-
-3. Run this file
-
-./zipformer/onnx_pretrained.py \
-  --model-filename $repo/exp/model-epoch-99-avg-1.onnx \
-  --tokens $repo/data/lang_bpe_500/tokens.txt \
-  $repo/test_wavs/1089-134686-0001.wav \
-  $repo/test_wavs/1221-135766-0001.wav \
-  $repo/test_wavs/1221-135766-0002.wav
+repo_url=https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12
+repo=$(basename $repo_url)
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+pushd $repo/exp
+git lfs pull --include "*.onnx"
+popd
+
+for m in model.onnx model.int8.onnx; do
+  python3 zipformer/onnx_pretrained.py \
+    --model-filename $repo/exp/$m \
+    --label-dict $repo/data/class_labels_indices.csv \
+    $repo/test_wavs/1.wav \
+    $repo/test_wavs/2.wav \
+    $repo/test_wavs/3.wav \
+    $repo/test_wavs/4.wav
+done
 """
 
 import argparse
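At its core, `onnx_pretrained.py` runs an onnxruntime session over fbank features. A minimal sketch, assuming the exported graph takes a single feature tensor; the dummy shape is an assumption, and the real script computes fbank features from the wave files:

    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
    features = np.zeros((1, 200, 80), dtype=np.float32)  # dummy (batch, time, feature)
    (logits,) = session.run(
        [session.get_outputs()[0].name],
        {session.get_inputs()[0].name: features},
    )
    probs = 1.0 / (1.0 + np.exp(-logits))  # sigmoid: multi-label scores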

egs/audioset/AT/zipformer/pretrained.py

@@ -33,7 +33,10 @@ Usage of this script:
 
 python3 zipformer/pretrained.py \
   --checkpoint $repo/exp/pretrained.pt \
   --label-dict $repo/data/class_labels_indices.csv \
-  $repo/test_wavs/1.wav
+  $repo/test_wavs/1.wav \
+  $repo/test_wavs/2.wav \
+  $repo/test_wavs/3.wav \
+  $repo/test_wavs/4.wav
 
 """