decode multiple files

This commit is contained in:
Fangjun Kuang 2024-04-09 13:45:25 +08:00
parent 9c08d97416
commit 7792e69f0d
3 changed files with 71 additions and 27 deletions

View File

@ -16,20 +16,67 @@ function test_pretrained() {
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
pushd $repo/exp pushd $repo/exp
git lfs pull --include pretrained.pt git lfs pull --include pretrained.pt
popd ln -s pretrained.pt epoch-99.pt
ls -lh ls -lh
popd
log "test pretrained.pt" log "test pretrained.pt"
for w in 1.wav 2.wav 3.wav; do python3 zipformer/pretrained.py \
log "test $w" --checkpoint $repo/exp/pretrained.pt \
python3 zipformer/pretrained.py \ --label-dict $repo/data/class_labels_indices.csv \
--checkpoint $repo/exp/pretrained.pt \ $repo/test_wavs/1.wav \
--label-dict $repo/data/class_labels_indices.csv \ $repo/test_wavs/2.wav \
$repo/test_wavs/$w $repo/test_wavs/3.wav \
done $repo/test_wavs/4.wav
log "test jit export"
ls -lh $repo/exp/
python3 zipformer/export.py \
--exp-dir $repo/exp \
--epoch 99 \
--avg 1 \
--use-averaged-model 0 \
--jit 1
ls -lh $repo/exp/
log "test jit models"
python3 zipformer/jit_pretrained.py \
--nn-model-filename $repo/exp/jit_script.pt \
--label-dict $repo/data/class_labels_indices.csv \
$repo/test_wavs/1.wav \
$repo/test_wavs/2.wav \
$repo/test_wavs/3.wav \
$repo/test_wavs/4.wav
log "test onnx export"
ls -lh $repo/exp/
python3 zipformer/export-onnx.py \
--exp-dir $repo/exp \
--epoch 99 \
--avg 1 \
--use-averaged-model 0
ls -lh $repo/exp/
pushd $repo/exp/
mv model-epoch-99-avg-1.onnx model.onnx
mv model-epoch-99-avg-1.int8.onnx model.int8.onnx
popd
ls -lh $repo/exp/
log "test onnx models"
for m in model.onnx model.int8.onnx; do
log "$m"
python3 zipformer/onnx_pretrained.py \
--model-filename $repo/exp/model.onnx \
--label-dict $repo/data/class_labels_indices.csv \
$repo/test_wavs/1.wav \
$repo/test_wavs/2.wav \
$repo/test_wavs/3.wav \
$repo/test_wavs/4.wav
done
} }
test_pretrained test_pretrained

View File

@ -25,7 +25,7 @@
Usage: Usage:
Note: This is a example for librispeech dataset, if you are using different Note: This is an example for AudioSet dataset, if you are using different
dataset, you should change the argument values according to your dataset. dataset, you should change the argument values according to your dataset.
(1) Export to torchscript model using torch.jit.script() (1) Export to torchscript model using torch.jit.script()

View File

@ -18,27 +18,22 @@
This script loads a checkpoint and uses it to decode waves. This script loads a checkpoint and uses it to decode waves.
You can generate the checkpoint with the following command: You can generate the checkpoint with the following command:
Note: This is a example for librispeech dataset, if you are using different Note: This is an example for the AudioSet dataset, if you are using different
dataset, you should change the argument values according to your dataset. dataset, you should change the argument values according to your dataset.
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9
Usage of this script: Usage of this script:
./zipformer/pretrained.py \ repo_url=https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12
--checkpoint ./zipformer/exp/pretrained.pt \ repo=$(basename $repo_url)
/path/to/foo.wav \ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
/path/to/bar.wav pushd $repo/exp
git lfs pull --include pretrained.pt
popd
python3 zipformer/pretrained.py \
You can also use `./zipformer/exp/epoch-xx.pt`. --checkpoint $repo/exp/pretrained.pt \
--label-dict $repo/data/class_labels_indices.csv \
Note: ./zipformer/exp/pretrained.pt is generated by ./zipformer/export.py $repo/test_wavs/1.wav
""" """
@ -189,7 +184,8 @@ def main():
topk_prob, topk_index = logit.sigmoid().topk(5) topk_prob, topk_index = logit.sigmoid().topk(5)
topk_labels = [label_dict[index.item()] for index in topk_index] topk_labels = [label_dict[index.item()] for index in topk_index]
logging.info( logging.info(
f"{filename}: Top 5 predicted labels are {topk_labels} with probability of {topk_prob.tolist()}" f"{filename}: Top 5 predicted labels are {topk_labels} with "
f"probability of {topk_prob.tolist()}"
) )
logging.info("Done") logging.info("Done")
@ -199,4 +195,5 @@ if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
main() main()