diff --git a/egs/audioset/AT/zipformer/export-onnx.py b/egs/audioset/AT/zipformer/export-onnx.py index 24bd431fc..af83c0e9c 100755 --- a/egs/audioset/AT/zipformer/export-onnx.py +++ b/egs/audioset/AT/zipformer/export-onnx.py @@ -202,10 +202,7 @@ class OnnxAudioTagger(nn.Module): logits = self.classifier(encoder_out) # (N, T, num_classes) # Note that this is slightly different from model.py for better # support of onnx - N = logits.shape[0] - for i in range(N): - logits[i, encoder_out_lens[i] :] = 0 - logits = logits.sum(dim=1) / encoder_out_lens.unsqueeze(-1) + logits = logits.mean(dim=1) return logits