mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Return probs in audio tagging onnx models (#1586)
This commit is contained in:
parent
fa5d861af0
commit
ba5b2e854b
@ -164,7 +164,7 @@ class OnnxAudioTagger(nn.Module):
|
|||||||
A 1-D tensor of shape (N,). Its dtype is torch.int64
|
A 1-D tensor of shape (N,). Its dtype is torch.int64
|
||||||
Returns:
|
Returns:
|
||||||
Return a tensor containing:
|
Return a tensor containing:
|
||||||
- logits, A 2-D tensor of shape (N, num_classes)
|
- probs, A 2-D tensor of shape (N, num_classes)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
x, x_lens = self.encoder_embed(x, x_lens)
|
x, x_lens = self.encoder_embed(x, x_lens)
|
||||||
@ -177,7 +177,8 @@ class OnnxAudioTagger(nn.Module):
|
|||||||
# Note that this is slightly different from model.py for better
|
# Note that this is slightly different from model.py for better
|
||||||
# support of onnx
|
# support of onnx
|
||||||
logits = logits.mean(dim=1)
|
logits = logits.mean(dim=1)
|
||||||
return logits
|
probs = logits.sigmoid()
|
||||||
|
return probs
|
||||||
|
|
||||||
|
|
||||||
def export_audio_tagging_model_onnx(
|
def export_audio_tagging_model_onnx(
|
||||||
@ -220,15 +221,16 @@ def export_audio_tagging_model_onnx(
|
|||||||
dynamic_axes={
|
dynamic_axes={
|
||||||
"x": {0: "N", 1: "T"},
|
"x": {0: "N", 1: "T"},
|
||||||
"x_lens": {0: "N"},
|
"x_lens": {0: "N"},
|
||||||
"logits": {0: "N"},
|
"probs": {0: "N"},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
meta_data = {
|
meta_data = {
|
||||||
"model_type": "zipformer2_at",
|
"model_type": "zipformer2",
|
||||||
"version": "1",
|
"version": "1",
|
||||||
"model_author": "k2-fsa",
|
"model_author": "k2-fsa",
|
||||||
"comment": "zipformer2 audio tagger",
|
"comment": "zipformer2 audio tagger",
|
||||||
|
"url": "https://github.com/k2-fsa/icefall/tree/master/egs/audioset/AT/zipformer",
|
||||||
}
|
}
|
||||||
logging.info(f"meta_data: {meta_data}")
|
logging.info(f"meta_data: {meta_data}")
|
||||||
|
|
||||||
|
@ -20,17 +20,17 @@ This script loads ONNX models and uses them to decode waves.
|
|||||||
|
|
||||||
Usage of this script:
|
Usage of this script:
|
||||||
|
|
||||||
repo_url=https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12
|
repo_url=https://huggingface.co/k2-fsa/sherpa-onnx-zipformer-audio-tagging-2024-04-09
|
||||||
repo=$(basename $repo_url)
|
repo=$(basename $repo_url)
|
||||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
git clone $repo_url
|
||||||
pushd $repo/exp
|
pushd $repo
|
||||||
git lfs pull --include "*.onnx"
|
git lfs pull --include "*.onnx"
|
||||||
popd
|
popd
|
||||||
|
|
||||||
for m in model.onnx model.int8.onnx; do
|
for m in model.onnx model.int8.onnx; do
|
||||||
python3 zipformer/onnx_pretrained.py \
|
python3 zipformer/onnx_pretrained.py \
|
||||||
--model-filename $repo/exp/model.onnx \
|
--model-filename $repo/model.onnx \
|
||||||
--label-dict $repo/data/class_labels_indices.csv \
|
--label-dict $repo/class_labels_indices.csv \
|
||||||
$repo/test_wavs/1.wav \
|
$repo/test_wavs/1.wav \
|
||||||
$repo/test_wavs/2.wav \
|
$repo/test_wavs/2.wav \
|
||||||
$repo/test_wavs/3.wav \
|
$repo/test_wavs/3.wav \
|
||||||
@ -125,7 +125,7 @@ class OnnxModel:
|
|||||||
A 2-D tensor of shape (N,). Its dtype is torch.int64
|
A 2-D tensor of shape (N,). Its dtype is torch.int64
|
||||||
Returns:
|
Returns:
|
||||||
Return a Tensor:
|
Return a Tensor:
|
||||||
- logits, its shape is (N, num_classes)
|
- probs, its shape is (N, num_classes)
|
||||||
"""
|
"""
|
||||||
out = self.model.run(
|
out = self.model.run(
|
||||||
[
|
[
|
||||||
@ -208,13 +208,14 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64)
|
feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64)
|
||||||
logits = model(features, feature_lengths)
|
probs = model(features, feature_lengths)
|
||||||
|
|
||||||
for filename, logit in zip(args.sound_files, logits):
|
for filename, prob in zip(args.sound_files, probs):
|
||||||
topk_prob, topk_index = logit.sigmoid().topk(5)
|
topk_prob, topk_index = prob.topk(5)
|
||||||
topk_labels = [label_dict[index.item()] for index in topk_index]
|
topk_labels = [label_dict[index.item()] for index in topk_index]
|
||||||
logging.info(
|
logging.info(
|
||||||
f"{filename}: Top 5 predicted labels are {topk_labels} with probability of {topk_prob.tolist()}"
|
f"{filename}: Top 5 predicted labels are {topk_labels} with "
|
||||||
|
f"probability of {topk_prob.tolist()}"
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("Decoding Done")
|
logging.info("Decoding Done")
|
||||||
|
@ -11,6 +11,7 @@ dill
|
|||||||
onnx>=1.15.0
|
onnx>=1.15.0
|
||||||
onnxruntime>=1.16.3
|
onnxruntime>=1.16.3
|
||||||
onnxoptimizer
|
onnxoptimizer
|
||||||
|
onnxsim
|
||||||
|
|
||||||
# style check session:
|
# style check session:
|
||||||
black==22.3.0
|
black==22.3.0
|
||||||
|
Loading…
x
Reference in New Issue
Block a user