mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
Return probs in audio tagging onnx models (#1586)
This commit is contained in:
parent
fa5d861af0
commit
ba5b2e854b
@ -164,7 +164,7 @@ class OnnxAudioTagger(nn.Module):
|
||||
A 1-D tensor of shape (N,). Its dtype is torch.int64
|
||||
Returns:
|
||||
Return a tensor containing:
|
||||
- logits, A 2-D tensor of shape (N, num_classes)
|
||||
- probs, A 2-D tensor of shape (N, num_classes)
|
||||
|
||||
"""
|
||||
x, x_lens = self.encoder_embed(x, x_lens)
|
||||
@ -177,7 +177,8 @@ class OnnxAudioTagger(nn.Module):
|
||||
# Note that this is slightly different from model.py for better
|
||||
# support of onnx
|
||||
logits = logits.mean(dim=1)
|
||||
return logits
|
||||
probs = logits.sigmoid()
|
||||
return probs
|
||||
|
||||
|
||||
def export_audio_tagging_model_onnx(
|
||||
@ -220,15 +221,16 @@ def export_audio_tagging_model_onnx(
|
||||
dynamic_axes={
|
||||
"x": {0: "N", 1: "T"},
|
||||
"x_lens": {0: "N"},
|
||||
"logits": {0: "N"},
|
||||
"probs": {0: "N"},
|
||||
},
|
||||
)
|
||||
|
||||
meta_data = {
|
||||
"model_type": "zipformer2_at",
|
||||
"model_type": "zipformer2",
|
||||
"version": "1",
|
||||
"model_author": "k2-fsa",
|
||||
"comment": "zipformer2 audio tagger",
|
||||
"url": "https://github.com/k2-fsa/icefall/tree/master/egs/audioset/AT/zipformer",
|
||||
}
|
||||
logging.info(f"meta_data: {meta_data}")
|
||||
|
||||
|
@ -20,17 +20,17 @@ This script loads ONNX models and uses them to decode waves.
|
||||
|
||||
Usage of this script:
|
||||
|
||||
repo_url=https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12
|
||||
repo_url=https://huggingface.co/k2-fsa/sherpa-onnx-zipformer-audio-tagging-2024-04-09
|
||||
repo=$(basename $repo_url)
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
pushd $repo/exp
|
||||
git clone $repo_url
|
||||
pushd $repo
|
||||
git lfs pull --include "*.onnx"
|
||||
popd
|
||||
|
||||
for m in model.onnx model.int8.onnx; do
|
||||
python3 zipformer/onnx_pretrained.py \
|
||||
--model-filename $repo/exp/model.onnx \
|
||||
--label-dict $repo/data/class_labels_indices.csv \
|
||||
--model-filename $repo/model.onnx \
|
||||
--label-dict $repo/class_labels_indices.csv \
|
||||
$repo/test_wavs/1.wav \
|
||||
$repo/test_wavs/2.wav \
|
||||
$repo/test_wavs/3.wav \
|
||||
@ -125,7 +125,7 @@ class OnnxModel:
|
||||
A 2-D tensor of shape (N,). Its dtype is torch.int64
|
||||
Returns:
|
||||
Return a Tensor:
|
||||
- logits, its shape is (N, num_classes)
|
||||
- probs, its shape is (N, num_classes)
|
||||
"""
|
||||
out = self.model.run(
|
||||
[
|
||||
@ -208,13 +208,14 @@ def main():
|
||||
)
|
||||
|
||||
feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64)
|
||||
logits = model(features, feature_lengths)
|
||||
probs = model(features, feature_lengths)
|
||||
|
||||
for filename, logit in zip(args.sound_files, logits):
|
||||
topk_prob, topk_index = logit.sigmoid().topk(5)
|
||||
for filename, prob in zip(args.sound_files, probs):
|
||||
topk_prob, topk_index = prob.topk(5)
|
||||
topk_labels = [label_dict[index.item()] for index in topk_index]
|
||||
logging.info(
|
||||
f"{filename}: Top 5 predicted labels are {topk_labels} with probability of {topk_prob.tolist()}"
|
||||
f"{filename}: Top 5 predicted labels are {topk_labels} with "
|
||||
f"probability of {topk_prob.tolist()}"
|
||||
)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
@ -11,6 +11,7 @@ dill
|
||||
onnx>=1.15.0
|
||||
onnxruntime>=1.16.3
|
||||
onnxoptimizer
|
||||
onnxsim
|
||||
|
||||
# style check session:
|
||||
black==22.3.0
|
||||
|
Loading…
x
Reference in New Issue
Block a user