fix style

This commit is contained in:
marcoyang 2024-04-07 15:30:36 +08:00
parent 686d2d9787
commit f3e8e42265
4 changed files with 13 additions and 17 deletions

View File

@ -17,7 +17,6 @@
import argparse import argparse
import inspect import inspect
import logging import logging
import pickle
from functools import lru_cache from functools import lru_cache
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional from typing import Any, Dict, Optional

View File

@ -160,8 +160,6 @@ def decode_dataset(
model: nn.Module, model: nn.Module,
) -> Dict: ) -> Dict:
num_cuts = 0 num_cuts = 0
embedding_dict = {}
teacher_embedding_dict = {}
try: try:
num_batches = len(dl) num_batches = len(dl)

View File

@ -62,7 +62,7 @@ use the exported ONNX models.
import argparse import argparse
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Dict, Tuple from typing import Dict
import k2 import k2
import onnx import onnx
@ -189,9 +189,9 @@ class OnnxAudioTagger(nn.Module):
x_lens: x_lens:
A 1-D tensor of shape (N,). Its dtype is torch.int64 A 1-D tensor of shape (N,). Its dtype is torch.int64
Returns: Returns:
Return a tuple containing: Return a tensor containing:
- encoder_out, A 3-D tensor of shape (N, T', joiner_dim) - logits, A 2-D tensor of shape (N, num_classes)
- encoder_out_lens, A 1-D tensor of shape (N,)
""" """
x, x_lens = self.encoder_embed(x, x_lens) x, x_lens = self.encoder_embed(x, x_lens)
src_key_padding_mask = make_pad_mask(x_lens) src_key_padding_mask = make_pad_mask(x_lens)
@ -200,13 +200,12 @@ class OnnxAudioTagger(nn.Module):
encoder_out = encoder_out.permute(1, 0, 2) # (N,T,C) encoder_out = encoder_out.permute(1, 0, 2) # (N,T,C)
logits = self.classifier(encoder_out) # (N, T, num_classes) logits = self.classifier(encoder_out) # (N, T, num_classes)
padding_mask = make_pad_mask(encoder_out_lens) # Note that this is slightly different from model.py for better
logits[padding_mask] = 0 # support of onnx
logits = logits.sum(dim=1) # mask the padding frames N = logits.shape[0]
logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as( for i in range(N):
logits logits[i, encoder_out_lens[i] :] = 0
) # normalize the logits logits = logits.sum(dim=1) / encoder_out_lens.unsqueeze(-1)
print(logits.shape)
return logits return logits
@ -237,7 +236,7 @@ def export_audio_tagging_model_onnx(
x = torch.zeros(1, 200, 80, dtype=torch.float32) x = torch.zeros(1, 200, 80, dtype=torch.float32)
x_lens = torch.tensor([200], dtype=torch.int64) x_lens = torch.tensor([200], dtype=torch.int64)
model = torch.jit.script(model) model = torch.jit.trace(model, (x, x_lens))
torch.onnx.export( torch.onnx.export(
model, model,

View File

@ -135,7 +135,6 @@ class OnnxModel:
meta = self.model.get_modelmeta().custom_metadata_map meta = self.model.get_modelmeta().custom_metadata_map
print(meta) print(meta)
def __call__( def __call__(
self, self,
x: torch.Tensor, x: torch.Tensor,
@ -162,6 +161,7 @@ class OnnxModel:
) )
return torch.from_numpy(out[0]) return torch.from_numpy(out[0])
def read_sound_files( def read_sound_files(
filenames: List[str], expected_sample_rate: float filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]: ) -> List[torch.Tensor]:
@ -232,7 +232,7 @@ def main():
feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64)
logits = model(features, feature_lengths) logits = model(features, feature_lengths)
for filename, logit in zip(args.sound_files, logits): for filename, logit in zip(args.sound_files, logits):
topk_prob, topk_index = logit.sigmoid().topk(5) topk_prob, topk_index = logit.sigmoid().topk(5)
topk_labels = [label_dict[index.item()] for index in topk_index] topk_labels = [label_dict[index.item()] for index in topk_index]