fix style

This commit is contained in:
marcoyang 2024-04-07 15:30:36 +08:00
parent 686d2d9787
commit f3e8e42265
4 changed files with 13 additions and 17 deletions

View File

@ -17,7 +17,6 @@
import argparse
import inspect
import logging
import pickle
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional

View File

@ -160,8 +160,6 @@ def decode_dataset(
model: nn.Module,
) -> Dict:
num_cuts = 0
embedding_dict = {}
teacher_embedding_dict = {}
try:
num_batches = len(dl)

View File

@ -62,7 +62,7 @@ use the exported ONNX models.
import argparse
import logging
from pathlib import Path
from typing import Dict, Tuple
from typing import Dict
import k2
import onnx
@ -189,9 +189,9 @@ class OnnxAudioTagger(nn.Module):
x_lens:
A 1-D tensor of shape (N,). Its dtype is torch.int64
Returns:
Return a tuple containing:
- encoder_out, A 3-D tensor of shape (N, T', joiner_dim)
- encoder_out_lens, A 1-D tensor of shape (N,)
Return a tensor containing:
- logits, A 2-D tensor of shape (N, num_classes)
"""
x, x_lens = self.encoder_embed(x, x_lens)
src_key_padding_mask = make_pad_mask(x_lens)
@ -200,13 +200,12 @@ class OnnxAudioTagger(nn.Module):
encoder_out = encoder_out.permute(1, 0, 2) # (N,T,C)
logits = self.classifier(encoder_out) # (N, T, num_classes)
padding_mask = make_pad_mask(encoder_out_lens)
logits[padding_mask] = 0
logits = logits.sum(dim=1) # mask the padding frames
logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(
logits
) # normalize the logits
print(logits.shape)
# Note that this is slightly different from model.py for better
# support of onnx
N = logits.shape[0]
for i in range(N):
logits[i, encoder_out_lens[i] :] = 0
logits = logits.sum(dim=1) / encoder_out_lens.unsqueeze(-1)
return logits
@ -237,7 +236,7 @@ def export_audio_tagging_model_onnx(
x = torch.zeros(1, 200, 80, dtype=torch.float32)
x_lens = torch.tensor([200], dtype=torch.int64)
model = torch.jit.script(model)
model = torch.jit.trace(model, (x, x_lens))
torch.onnx.export(
model,

View File

@ -135,7 +135,6 @@ class OnnxModel:
meta = self.model.get_modelmeta().custom_metadata_map
print(meta)
def __call__(
self,
x: torch.Tensor,
@ -162,6 +161,7 @@ class OnnxModel:
)
return torch.from_numpy(out[0])
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
@ -232,7 +232,7 @@ def main():
feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64)
logits = model(features, feature_lengths)
for filename, logit in zip(args.sound_files, logits):
topk_prob, topk_index = logit.sigmoid().topk(5)
topk_labels = [label_dict[index.item()] for index in topk_index]