Return probs in audio tagging onnx models (#1586)

This commit is contained in:
Fangjun Kuang 2024-04-10 09:03:30 +08:00 committed by GitHub
parent fa5d861af0
commit ba5b2e854b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 18 additions and 14 deletions

View File

@ -164,7 +164,7 @@ class OnnxAudioTagger(nn.Module):
A 1-D tensor of shape (N,). Its dtype is torch.int64
Returns:
Return a tensor containing:
- logits, A 2-D tensor of shape (N, num_classes)
- probs, A 2-D tensor of shape (N, num_classes)
"""
x, x_lens = self.encoder_embed(x, x_lens)
@ -177,7 +177,8 @@ class OnnxAudioTagger(nn.Module):
# Note that this is slightly different from model.py for better
# support of onnx
logits = logits.mean(dim=1)
return logits
probs = logits.sigmoid()
return probs
def export_audio_tagging_model_onnx(
@ -220,15 +221,16 @@ def export_audio_tagging_model_onnx(
dynamic_axes={
"x": {0: "N", 1: "T"},
"x_lens": {0: "N"},
"logits": {0: "N"},
"probs": {0: "N"},
},
)
meta_data = {
"model_type": "zipformer2_at",
"model_type": "zipformer2",
"version": "1",
"model_author": "k2-fsa",
"comment": "zipformer2 audio tagger",
"url": "https://github.com/k2-fsa/icefall/tree/master/egs/audioset/AT/zipformer",
}
logging.info(f"meta_data: {meta_data}")

View File

@ -20,17 +20,17 @@ This script loads ONNX models and uses them to decode waves.
Usage of this script:
repo_url=https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12
repo_url=https://huggingface.co/k2-fsa/sherpa-onnx-zipformer-audio-tagging-2024-04-09
repo=$(basename $repo_url)
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
pushd $repo/exp
git clone $repo_url
pushd $repo
git lfs pull --include "*.onnx"
popd
for m in model.onnx model.int8.onnx; do
python3 zipformer/onnx_pretrained.py \
--model-filename $repo/exp/model.onnx \
--label-dict $repo/data/class_labels_indices.csv \
--model-filename $repo/model.onnx \
--label-dict $repo/class_labels_indices.csv \
$repo/test_wavs/1.wav \
$repo/test_wavs/2.wav \
$repo/test_wavs/3.wav \
@ -125,7 +125,7 @@ class OnnxModel:
A 2-D tensor of shape (N,). Its dtype is torch.int64
Returns:
Return a Tensor:
- logits, its shape is (N, num_classes)
- probs, its shape is (N, num_classes)
"""
out = self.model.run(
[
@ -208,13 +208,14 @@ def main():
)
feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64)
logits = model(features, feature_lengths)
probs = model(features, feature_lengths)
for filename, logit in zip(args.sound_files, logits):
topk_prob, topk_index = logit.sigmoid().topk(5)
for filename, prob in zip(args.sound_files, probs):
topk_prob, topk_index = prob.topk(5)
topk_labels = [label_dict[index.item()] for index in topk_index]
logging.info(
f"{filename}: Top 5 predicted labels are {topk_labels} with probability of {topk_prob.tolist()}"
f"{filename}: Top 5 predicted labels are {topk_labels} with "
f"probability of {topk_prob.tolist()}"
)
logging.info("Decoding Done")

View File

@ -11,6 +11,7 @@ dill
onnx>=1.15.0
onnxruntime>=1.16.3
onnxoptimizer
onnxsim
# style check session:
black==22.3.0