diff --git a/.github/scripts/audioset/AT/run.sh b/.github/scripts/audioset/AT/run.sh
index 8f0f210ee..81cafd64b 100755
--- a/.github/scripts/audioset/AT/run.sh
+++ b/.github/scripts/audioset/AT/run.sh
@@ -16,20 +16,70 @@ function test_pretrained() {
   GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
   pushd $repo/exp
   git lfs pull --include pretrained.pt
-  popd
-
+  # Alias the checkpoint as epoch-99.pt for the `--epoch 99 --avg 1` export
+  # runs below (presumably they look the file up by epoch number).
+  ln -s pretrained.pt epoch-99.pt
   ls -lh
+  popd
 
   log "test pretrained.pt"
 
-  for w in 1.wav 2.wav 3.wav; do
-    log "test $w"
-    python3 zipformer/pretrained.py \
-      --checkpoint $repo/exp/pretrained.pt \
-      --label-dict $repo/data/class_labels_indices.csv \
-      $repo/test_wavs/$w
-  done
+  python3 zipformer/pretrained.py \
+    --checkpoint $repo/exp/pretrained.pt \
+    --label-dict $repo/data/class_labels_indices.csv \
+    $repo/test_wavs/1.wav \
+    $repo/test_wavs/2.wav \
+    $repo/test_wavs/3.wav \
+    $repo/test_wavs/4.wav
 
+  log "test jit export"
+  ls -lh $repo/exp/
+  python3 zipformer/export.py \
+    --exp-dir $repo/exp \
+    --epoch 99 \
+    --avg 1 \
+    --use-averaged-model 0 \
+    --jit 1
+  ls -lh $repo/exp/
+
+  log "test jit models"
+  python3 zipformer/jit_pretrained.py \
+    --nn-model-filename $repo/exp/jit_script.pt \
+    --label-dict $repo/data/class_labels_indices.csv \
+    $repo/test_wavs/1.wav \
+    $repo/test_wavs/2.wav \
+    $repo/test_wavs/3.wav \
+    $repo/test_wavs/4.wav
+
+  log "test onnx export"
+  ls -lh $repo/exp/
+  python3 zipformer/export-onnx.py \
+    --exp-dir $repo/exp \
+    --epoch 99 \
+    --avg 1 \
+    --use-averaged-model 0
+
+  ls -lh $repo/exp/
+
+  pushd $repo/exp/
+  mv model-epoch-99-avg-1.onnx model.onnx
+  mv model-epoch-99-avg-1.int8.onnx model.int8.onnx
+  popd
+
+  ls -lh $repo/exp/
+
+  log "test onnx models"
+  # Exercise every exported ONNX model ($m), not only the fp32 one.
+  for m in model.onnx model.int8.onnx; do
+    log "$m"
+    python3 zipformer/onnx_pretrained.py \
+      --model-filename $repo/exp/$m \
+      --label-dict $repo/data/class_labels_indices.csv \
+      $repo/test_wavs/1.wav \
+      $repo/test_wavs/2.wav \
+      $repo/test_wavs/3.wav \
+      $repo/test_wavs/4.wav
+  done
 }
 
 test_pretrained
diff --git 
a/egs/audioset/AT/zipformer/export.py b/egs/audioset/AT/zipformer/export.py index bdcf8b7dd..38bf6c8c5 100755 --- a/egs/audioset/AT/zipformer/export.py +++ b/egs/audioset/AT/zipformer/export.py @@ -25,7 +25,7 @@ Usage: -Note: This is a example for librispeech dataset, if you are using different +Note: This is an example for the AudioSet dataset, if you are using a different dataset, you should change the argument values according to your dataset. (1) Export to torchscript model using torch.jit.script() diff --git a/egs/audioset/AT/zipformer/pretrained.py b/egs/audioset/AT/zipformer/pretrained.py index 60e4d0518..460e8b2ed 100755 --- a/egs/audioset/AT/zipformer/pretrained.py +++ b/egs/audioset/AT/zipformer/pretrained.py @@ -18,27 +18,22 @@ This script loads a checkpoint and uses it to decode waves. You can generate the checkpoint with the following command: -Note: This is a example for librispeech dataset, if you are using different +Note: This is an example for the AudioSet dataset, if you are using a different dataset, you should change the argument values according to your dataset. - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --tokens data/lang_bpe_500/tokens.txt \ - --epoch 30 \ - --avg 9 - Usage of this script: -./zipformer/pretrained.py \ - --checkpoint ./zipformer/exp/pretrained.pt \ - /path/to/foo.wav \ - /path/to/bar.wav + repo_url=https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12 + repo=$(basename $repo_url) + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url + pushd $repo/exp + git lfs pull --include pretrained.pt + popd - -You can also use `./zipformer/exp/epoch-xx.pt`. 
- -Note: ./zipformer/exp/pretrained.pt is generated by ./zipformer/export.py + python3 zipformer/pretrained.py \ + --checkpoint $repo/exp/pretrained.pt \ + --label-dict $repo/data/class_labels_indices.csv \ + $repo/test_wavs/1.wav """ @@ -189,7 +184,8 @@ def main(): topk_prob, topk_index = logit.sigmoid().topk(5) topk_labels = [label_dict[index.item()] for index in topk_index] logging.info( - f"{filename}: Top 5 predicted labels are {topk_labels} with probability of {topk_prob.tolist()}" + f"{filename}: Top 5 predicted labels are {topk_labels} with " + f"probability of {topk_prob.tolist()}" ) logging.info("Done") @@ -199,4 +195,5 @@ if __name__ == "__main__": formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) + main()