small fixes

commit cff243eb47 (parent 7792e69f0d)
@@ -6,56 +6,28 @@
 """
 This script exports a transducer model from PyTorch to ONNX.
 
-We use the pre-trained model from
-https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
-as an example to show how to use this file.
+Usage of this script:
 
-1. Download the pre-trained model
-
-cd egs/librispeech/ASR
-
-repo_url=https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12#/
-GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
-repo=$(basename $repo_url)
-
-pushd $repo
-git lfs pull --include "exp/pretrained.pt"
-
-cd exp
-ln -s pretrained.pt epoch-99.pt
-popd
-
-2. Export the model to ONNX
-
-./zipformer/export-onnx.py \
-  --use-averaged-model 0 \
-  --epoch 99 \
-  --avg 1 \
-  --exp-dir $repo/exp \
-  --num-encoder-layers "2,2,3,4,3,2" \
-  --downsampling-factor "1,2,4,8,4,2" \
-  --feedforward-dim "512,768,1024,1536,1024,768" \
-  --num-heads "4,4,4,8,4,4" \
-  --encoder-dim "192,256,384,512,384,256" \
-  --query-head-dim 32 \
-  --value-head-dim 12 \
-  --pos-head-dim 4 \
-  --pos-dim 48 \
-  --encoder-unmasked-dim "192,192,256,256,256,192" \
-  --cnn-module-kernel "31,31,15,15,15,31" \
-  --decoder-dim 512 \
-  --joiner-dim 512 \
-  --causal False \
-  --chunk-size "16,32,64,-1" \
-  --left-context-frames "64,128,256,-1"
-
-It will generate the following 3 files inside $repo/exp:
-
-  - encoder-epoch-99-avg-1.onnx
-  - decoder-epoch-99-avg-1.onnx
-  - joiner-epoch-99-avg-1.onnx
-
-See ./onnx_pretrained.py and ./onnx_check.py for how to
+repo_url=https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12
+repo=$(basename $repo_url)
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+pushd $repo/exp
+git lfs pull --include pretrained.pt
+ln -s pretrained.pt epoch-99.pt
+popd
+
+python3 zipformer/export-onnx.py \
+  --exp-dir $repo/exp \
+  --epoch 99 \
+  --avg 1 \
+  --use-averaged-model 0
+
+pushd $repo/exp
+mv model-epoch-99-avg-1.onnx model.onnx
+mv model-epoch-99-avg-1.int8.onnx model.int8.onnx
+popd
+
+See ./onnx_pretrained.py for how to
 use the exported ONNX models.
 """
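Note: as a quick sanity check after the export and rename steps above, the resulting model.onnx can be opened with onnxruntime. A minimal sketch; the provider choice is ours, and the printed input/output names are simply whatever the export wrote, not names taken from the script:

# Sanity-check sketch for the exported/renamed ONNX file from the steps above.
# Assumes onnxruntime is installed; runs on CPU only.
import onnxruntime as ort

session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
for inp in session.get_inputs():
    print("input :", inp.name, inp.shape, inp.type)
for out in session.get_outputs():
    print("output:", out.name, out.shape, out.type)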
@@ -261,6 +233,16 @@ def export_audio_tagging_model_onnx(
     add_meta_data(filename=filename, meta_data=meta_data)
 
 
+def optimize_model(filename):
+    # see https://github.com/microsoft/onnxruntime/issues/1899#issuecomment-534806537
+    from onnx import optimizer
+
+    passes = ["extract_constant_to_initializer", "eliminate_unused_initializer"]
+    onnx_model = onnx.load(filename)
+    optimized_model = optimizer.optimize(onnx_model, passes)
+    onnx.save(optimized_model, filename)
+
+
 @torch.no_grad()
 def main():
     args = get_parser().parse_args()
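Note: `from onnx import optimizer` only works with older onnx releases; the optimizer was split out into the separate onnxoptimizer package and removed from onnx itself around version 1.9, so on a newer environment the helper above may fail to import. A sketch of the same in-place optimization with the standalone package, assuming onnxoptimizer is installed; the helper name optimize_model_compat is ours, not part of the commit:

# Sketch: the same optimization passes via the standalone onnxoptimizer package,
# which replaced onnx.optimizer in newer onnx releases.
import onnx
import onnxoptimizer


def optimize_model_compat(filename):
    passes = ["extract_constant_to_initializer", "eliminate_unused_initializer"]
    onnx_model = onnx.load(filename)
    optimized_model = onnxoptimizer.optimize(onnx_model, passes)
    onnx.save(optimized_model, filename)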
@@ -389,6 +371,7 @@ def main():
         model_filename,
         opset_version=opset_version,
     )
+    optimize_model(model_filename)
     logging.info(f"Exported audio tagging model to {model_filename}")
 
     # Generate int8 quantization models
@@ -403,6 +386,7 @@ def main():
         op_types_to_quantize=["MatMul"],
         weight_type=QuantType.QInt8,
     )
+    optimize_model(model_filename_int8)
 
 
 if __name__ == "__main__":
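Note: the `op_types_to_quantize` and `weight_type` arguments visible above belong to onnxruntime's `quantize_dynamic`. For context, a sketch of the full call with the options shown; the file names here are placeholders, not the script's actual variables:

# Sketch: dynamic int8 quantization matching the options shown in the hunk above.
from onnxruntime.quantization import QuantType, quantize_dynamic

quantize_dynamic(
    model_input="model-epoch-99-avg-1.onnx",
    model_output="model-epoch-99-avg-1.int8.onnx",
    op_types_to_quantize=["MatMul"],
    weight_type=QuantType.QInt8,
)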
@@ -42,6 +42,7 @@ load it by `torch.jit.load("jit_script.pt")`.
 Check ./jit_pretrained.py for its usage.
 
 Check https://github.com/k2-fsa/sherpa
+and https://github.com/k2-fsa/sherpa-onnx
 for how to use the exported models outside of icefall.
 
 (2) Export `model.state_dict()`
@@ -55,13 +56,13 @@ for how to use the exported models outside of icefall.
 It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
 load it by `icefall.checkpoint.load_checkpoint()`.
 
-To use the generated file with `zipformer/decode.py`,
+To use the generated file with `zipformer/evaluate.py`,
 you can do:
 
     cd /path/to/exp_dir
     ln -s pretrained.pt epoch-9999.pt
 
-    cd /path/to/egs/librispeech/ASR
+    cd /path/to/egs/audioset/AT
     ./zipformer/evaluate.py \
         --exp-dir ./zipformer/exp \
         --use-averaged-model False \
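Note: the `ln -s pretrained.pt epoch-9999.pt` trick works because `evaluate.py` looks checkpoints up by epoch number, so the exported state_dict simply appears as epoch 9999. A hedged sketch of loading the exported file directly; it assumes the usual icefall checkpoint layout with the weights under a "model" key:

# Sketch: load the exported pretrained.pt outside of evaluate.py.
# Assumes the usual icefall checkpoint layout ({"model": state_dict, ...}).
import torch

checkpoint = torch.load("pretrained.pt", map_location="cpu")
state_dict = checkpoint.get("model", checkpoint)  # fall back to a bare state_dict
# model.load_state_dict(state_dict)  # `model` is the zipformer built by the recipe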
@@ -28,10 +28,20 @@ You can use the following command to get the exported models:
 
 Usage of this script:
 
-./zipformer/jit_pretrained.py \
-  --nn-model-filename ./zipformer/exp/cpu_jit.pt \
-  /path/to/foo.wav \
-  /path/to/bar.wav
+repo_url=https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12
+repo=$(basename $repo_url)
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+pushd $repo/exp
+git lfs pull --include pretrained.pt
+popd
+
+python3 zipformer/jit_pretrained.py \
+  --nn-model-filename $repo/exp/jit_script.pt \
+  --label-dict $repo/data/class_labels_indices.csv \
+  $repo/test_wavs/1.wav \
+  $repo/test_wavs/2.wav \
+  $repo/test_wavs/3.wav \
+  $repo/test_wavs/4.wav
 """
 
 import argparse
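Note: the `jit_script.pt` referenced above is a plain TorchScript module, so it can also be driven directly from Python. A sketch under stated assumptions: the (batch, frames, 80) fbank feature shape and the forward(features, feature_lengths) signature are guesses about this recipe, not taken from jit_pretrained.py, which shows how the real features are computed:

# Sketch: run the TorchScript export on dummy fbank features.
# The feature shape and forward signature are assumptions about this recipe.
import torch

model = torch.jit.load("jit_script.pt", map_location="cpu")
model.eval()

features = torch.randn(1, 1000, 80)     # (batch, num_frames, num_fbank_bins)
feature_lengths = torch.tensor([1000])
with torch.no_grad():
    logits = model(features, feature_lengths)
print(logits.shape)  # expected: (batch, num_classes)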
@@ -168,7 +178,8 @@ def main():
         topk_prob, topk_index = logit.sigmoid().topk(5)
         topk_labels = [label_dict[index.item()] for index in topk_index]
         logging.info(
-            f"{filename}: Top 5 predicted labels are {topk_labels} with probability of {topk_prob.tolist()}"
+            f"{filename}: Top 5 predicted labels are {topk_labels} with "
+            f"probability of {topk_prob.tolist()}"
         )
 
     logging.info("Done")
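Note: only the wrapping of the log message changes here; the top-5 computation itself is a sigmoid over the class logits followed by topk. A tiny self-contained illustration with random logits and a made-up label table (AudioSet has 527 classes):

# Illustration of the top-5 computation behind the log line above.
# Random logits and a synthetic label table stand in for the real model output.
import torch

label_dict = {i: f"class_{i}" for i in range(527)}  # AudioSet has 527 classes
logit = torch.randn(527)

topk_prob, topk_index = logit.sigmoid().topk(5)
topk_labels = [label_dict[index.item()] for index in topk_index]
print(topk_labels, topk_prob.tolist())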
@@ -17,48 +17,25 @@
 # limitations under the License.
 """
 This script loads ONNX models and uses them to decode waves.
-You can use the following command to get the exported models:
 
-We use the pre-trained model from
-https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12#/
-as an example to show how to use this file.
+Usage of this script:
 
-1. Download the pre-trained model
-
-cd egs/librispeech/ASR
-
-repo_url=https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12#/
-GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
-repo=$(basename $repo_url)
-
-pushd $repo
-git lfs pull --include "exp/pretrained.pt"
-
-cd exp
-ln -s pretrained.pt epoch-99.pt
-popd
-
-2. Export the model to ONNX
-
-./zipformer/export-onnx.py \
-  --use-averaged-model 0 \
-  --epoch 99 \
-  --avg 1 \
-  --exp-dir $repo/exp \
-  --causal False
-
-It will generate the following 3 files inside $repo/exp:
-
-  - model-epoch-99-avg-1.onnx
-
-3. Run this file
-
-./zipformer/onnx_pretrained.py \
-  --model-filename $repo/exp/model-epoch-99-avg-1.onnx \
-  --tokens $repo/data/lang_bpe_500/tokens.txt \
-  $repo/test_wavs/1089-134686-0001.wav \
-  $repo/test_wavs/1221-135766-0001.wav \
-  $repo/test_wavs/1221-135766-0002.wav
+repo_url=https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12
+repo=$(basename $repo_url)
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+pushd $repo/exp
+git lfs pull --include "*.onnx"
+popd
+
+for m in model.onnx model.int8.onnx; do
+  python3 zipformer/onnx_pretrained.py \
+    --model-filename $repo/exp/$m \
+    --label-dict $repo/data/class_labels_indices.csv \
+    $repo/test_wavs/1.wav \
+    $repo/test_wavs/2.wav \
+    $repo/test_wavs/3.wav \
+    $repo/test_wavs/4.wav
+done
 """
 
 import argparse
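Note: the --label-dict file used above is the standard AudioSet class_labels_indices.csv. A sketch of turning it into the index-to-name mapping the script needs, assuming the usual index,mid,display_name header:

# Sketch: build an index -> display_name mapping from class_labels_indices.csv.
# Assumes the standard AudioSet header: index,mid,display_name
import csv

label_dict = {}
with open("class_labels_indices.csv", newline="") as f:
    for row in csv.DictReader(f):
        label_dict[int(row["index"])] = row["display_name"]

print(len(label_dict), "labels loaded")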
@@ -33,7 +33,10 @@ Usage of this script:
 python3 zipformer/pretrained.py \
   --checkpoint $repo/exp/pretrained.pt \
   --label-dict $repo/data/class_labels_indices.csv \
-  $repo/test_wavs/1.wav
+  $repo/test_wavs/1.wav \
+  $repo/test_wavs/2.wav \
+  $repo/test_wavs/3.wav \
+  $repo/test_wavs/4.wav
 """
 
 