support onnx export with batch size 1; also works for batch processing, but the results might be affected by the padding

This commit is contained in:
marcoyang 2024-04-07 15:45:28 +08:00
parent f3e8e42265
commit 01b744f127

View File

@ -202,10 +202,7 @@ class OnnxAudioTagger(nn.Module):
logits = self.classifier(encoder_out) # (N, T, num_classes)
# Note that this is slightly different from model.py for better
# support of onnx
N = logits.shape[0]
for i in range(N):
logits[i, encoder_out_lens[i] :] = 0
logits = logits.sum(dim=1) / encoder_out_lens.unsqueeze(-1)
logits = logits.mean(dim=1)
return logits