From f3e8e42265639156c7ded570726f68b621caf159 Mon Sep 17 00:00:00 2001
From: marcoyang
Date: Sun, 7 Apr 2024 15:30:36 +0800
Subject: [PATCH] fix style

---
 egs/audioset/AT/zipformer/at_datamodule.py   |  1 -
 egs/audioset/AT/zipformer/evaluate.py        |  2 --
 egs/audioset/AT/zipformer/export-onnx.py     | 23 ++++++++++----------
 egs/audioset/AT/zipformer/onnx_pretrained.py |  4 ++--
 4 files changed, 13 insertions(+), 17 deletions(-)

diff --git a/egs/audioset/AT/zipformer/at_datamodule.py b/egs/audioset/AT/zipformer/at_datamodule.py
index 3b18976ee..66497c1ca 100644
--- a/egs/audioset/AT/zipformer/at_datamodule.py
+++ b/egs/audioset/AT/zipformer/at_datamodule.py
@@ -17,7 +17,6 @@
 import argparse
 import inspect
 import logging
-import pickle
 from functools import lru_cache
 from pathlib import Path
 from typing import Any, Dict, Optional
diff --git a/egs/audioset/AT/zipformer/evaluate.py b/egs/audioset/AT/zipformer/evaluate.py
index 487c0f901..b52a284d0 100644
--- a/egs/audioset/AT/zipformer/evaluate.py
+++ b/egs/audioset/AT/zipformer/evaluate.py
@@ -160,8 +160,6 @@ def decode_dataset(
     model: nn.Module,
 ) -> Dict:
     num_cuts = 0
-    embedding_dict = {}
-    teacher_embedding_dict = {}
 
     try:
         num_batches = len(dl)
diff --git a/egs/audioset/AT/zipformer/export-onnx.py b/egs/audioset/AT/zipformer/export-onnx.py
index 5fc98f8b6..24bd431fc 100755
--- a/egs/audioset/AT/zipformer/export-onnx.py
+++ b/egs/audioset/AT/zipformer/export-onnx.py
@@ -62,7 +62,7 @@ use the exported ONNX models.
 import argparse
 import logging
 from pathlib import Path
-from typing import Dict, Tuple
+from typing import Dict
 
 import k2
 import onnx
@@ -189,9 +189,9 @@ class OnnxAudioTagger(nn.Module):
           x_lens:
             A 1-D tensor of shape (N,). Its dtype is torch.int64
         Returns:
-          Return a tuple containing:
-            - encoder_out, A 3-D tensor of shape (N, T', joiner_dim)
-            - encoder_out_lens, A 1-D tensor of shape (N,)
+          Return a tensor containing:
+            - logits, A 2-D tensor of shape (N, num_classes)
+
         """
         x, x_lens = self.encoder_embed(x, x_lens)
         src_key_padding_mask = make_pad_mask(x_lens)
@@ -200,13 +200,12 @@ class OnnxAudioTagger(nn.Module):
         encoder_out = encoder_out.permute(1, 0, 2)  # (N,T,C)
 
         logits = self.classifier(encoder_out)  # (N, T, num_classes)
-        padding_mask = make_pad_mask(encoder_out_lens)
-        logits[padding_mask] = 0
-        logits = logits.sum(dim=1)  # mask the padding frames
-        logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(
-            logits
-        )  # normalize the logits
-        print(logits.shape)
+        # Note that this is slightly different from model.py for better
+        # support of onnx
+        N = logits.shape[0]
+        for i in range(N):
+            logits[i, encoder_out_lens[i] :] = 0
+        logits = logits.sum(dim=1) / encoder_out_lens.unsqueeze(-1)
 
         return logits
 
@@ -237,7 +236,7 @@ def export_audio_tagging_model_onnx(
     x = torch.zeros(1, 200, 80, dtype=torch.float32)
     x_lens = torch.tensor([200], dtype=torch.int64)
 
-    model = torch.jit.script(model)
+    model = torch.jit.trace(model, (x, x_lens))
 
     torch.onnx.export(
         model,
diff --git a/egs/audioset/AT/zipformer/onnx_pretrained.py b/egs/audioset/AT/zipformer/onnx_pretrained.py
index 156b177e5..c7753715a 100755
--- a/egs/audioset/AT/zipformer/onnx_pretrained.py
+++ b/egs/audioset/AT/zipformer/onnx_pretrained.py
@@ -135,7 +135,6 @@ class OnnxModel:
         meta = self.model.get_modelmeta().custom_metadata_map
         print(meta)
 
-
     def __call__(
        self,
        x: torch.Tensor,
@@ -162,6 +161,7 @@ class OnnxModel:
         )
         return torch.from_numpy(out[0])
 
+
 def read_sound_files(
     filenames: List[str], expected_sample_rate: float
 ) -> List[torch.Tensor]:
@@ -232,7 +232,7 @@ def main():
     feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64)
 
     logits = model(features, feature_lengths)
-
+
     for filename, logit in zip(args.sound_files, logits):
         topk_prob, topk_index = logit.sigmoid().topk(5)
         topk_labels = [label_dict[index.item()] for index in topk_index]
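
The export-onnx.py hunk above replaces boolean-mask pooling with a per-sample
loop, which the in-diff comment attributes to better ONNX support. Below is a
minimal sketch checking that the two formulations produce the same pooled
logits; the helper names masked_mean_old/masked_mean_new are illustrative, and
the inline arange-based padding mask stands in for icefall's make_pad_mask:

import torch


def masked_mean_old(logits: torch.Tensor, lens: torch.Tensor) -> torch.Tensor:
    """Boolean-mask pooling, as removed by the patch."""
    T = logits.shape[1]
    padding_mask = torch.arange(T).unsqueeze(0) >= lens.unsqueeze(1)  # (N, T)
    logits = logits.clone()
    logits[padding_mask] = 0  # zero the padding frames
    return logits.sum(dim=1) / (~padding_mask).sum(dim=1).unsqueeze(-1)


def masked_mean_new(logits: torch.Tensor, lens: torch.Tensor) -> torch.Tensor:
    """Per-sample loop, as added by the patch."""
    logits = logits.clone()
    for i in range(logits.shape[0]):
        logits[i, lens[i]:] = 0
    return logits.sum(dim=1) / lens.unsqueeze(-1)


logits = torch.randn(2, 5, 3)  # (N, T, num_classes)
lens = torch.tensor([5, 3])    # valid frames per sample
assert torch.allclose(masked_mean_old(logits, lens), masked_mean_new(logits, lens))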
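
The companion change in the same file swaps torch.jit.script for
torch.jit.trace before exporting. A minimal sketch of that trace-then-export
pattern follows; the Toy module is a hypothetical stand-in for
OnnxAudioTagger, and the input/output names and dynamic axes are illustrative
rather than the recipe's actual export arguments:

import torch
import torch.nn as nn


class Toy(nn.Module):
    """Stand-in model: masked mean pooling over the time axis."""

    def forward(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor:
        out = x.clone()
        for i in range(out.shape[0]):
            out[i, x_lens[i]:] = 0
        return out.sum(dim=1) / x_lens.unsqueeze(-1)


model = Toy().eval()
# Dummy inputs with the same shapes as in export_audio_tagging_model_onnx.
x = torch.zeros(1, 200, 80, dtype=torch.float32)
x_lens = torch.tensor([200], dtype=torch.int64)

# Tracing records the ops executed on the example inputs, so the module
# need not be fully scriptable.
traced = torch.jit.trace(model, (x, x_lens))
torch.onnx.export(
    traced,
    (x, x_lens),
    "audio_tagger.onnx",
    input_names=["x", "x_lens"],
    output_names=["logits"],
    dynamic_axes={"x": {0: "N", 1: "T"}, "x_lens": {0: "N"}, "logits": {0: "N"}},
    opset_version=13,
)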